// Copyright 2019 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Dialect/HAL/IR/HALOps.h"

#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
#include "iree/compiler/Dialect/Util/IR/UtilTypes.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/FunctionImplementation.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

namespace mlir::iree_compiler::IREE::HAL {

namespace {

// We arbitrarily say that unbounded dimensions in a torch program cannot
// exceed 53 bits, making the maximum safe dimension 9007199254740991. The
// astute reader will note that this is also the maximum safe integer value in
// JavaScript, which also "happens" to be the largest integer that a 64-bit
// double can represent exactly (it has a 53-bit significand). We need a
// maximum and in the absence of a better choice, with this one we are at least
// in good company. This limit is also used in the frontends.
//
// NOTE: constexpr variables in an anonymous namespace already have internal
// linkage; `static` would be redundant.
constexpr uint64_t MAX_DIM_VALUE = (static_cast<uint64_t>(1) << 53) - 1;
static_assert(MAX_DIM_VALUE == 9007199254740991ull,
              "MAX_DIM_VALUE must match the value documented above");

// Similarly we use a very conservative maximum rank value for specifying
// ranges of runtime rank resolution functions. Various frameworks have hard
// and practical limits ranging from 32 (numpy) to hundreds. At the time of
// writing, PyTorch throws weird errors if trying to print a tensor with a rank
// greater than 992. We really just want a smallish integer value to bound
// arithmetic, so we use an arbitrary maximum.
constexpr uint64_t MAX_RANK_VALUE = 4096;

} // namespace

//===----------------------------------------------------------------------===//
// custom<DeviceQueueAffinityList>($devices, type($devices), $queue_affinities)
//===----------------------------------------------------------------------===//

// Parses a bracketed list of `(%device, %queue_affinity : device_type,
// queue_affinity_type)` tuples into parallel operand/type lists.
//
// An empty list `[]` is accepted. printDeviceQueueAffinityList emits only the
// brackets (and whitespace) when there are no devices, so this keeps the
// print/parse round trip working for that case.
//
// NOTE(review): the parsed queue affinity type is consumed but not returned.
// type($queue_affinities) is not passed to this directive, so the operand type
// is presumably resolved from the op definition; confirm against the ODS.
static ParseResult parseDeviceQueueAffinityList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &devices,
    SmallVectorImpl<Type> &deviceTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &queueAffinities) {
  if (failed(parser.parseLSquare())) {
    return failure();
  }
  // Empty list: matches the printer output for zero devices.
  if (succeeded(parser.parseOptionalRSquare())) {
    return success();
  }
  do {
    OpAsmParser::UnresolvedOperand device;
    Type deviceType;
    OpAsmParser::UnresolvedOperand queueAffinity;
    Type queueAffinityType;
    if (failed(parser.parseLParen()) || failed(parser.parseOperand(device)) ||
        failed(parser.parseComma()) ||
        failed(parser.parseOperand(queueAffinity)) ||
        failed(parser.parseColon()) || failed(parser.parseType(deviceType)) ||
        failed(parser.parseComma()) ||
        failed(parser.parseType(queueAffinityType)) ||
        failed(parser.parseRParen())) {
      return failure();
    }
    devices.push_back(device);
    deviceTypes.push_back(deviceType);
    queueAffinities.push_back(queueAffinity);
  } while (succeeded(parser.parseOptionalComma()));
  if (failed(parser.parseRSquare())) {
    return failure();
  }
  return success();
}

// Inverse of parseDeviceQueueAffinityList. Prints one
// `(%device, %affinity : device_type, affinity_type)` tuple per line inside an
// indented `[` ... `]` block, with entries separated by commas.
static void printDeviceQueueAffinityList(OpAsmPrinter &p, Operation *,
                                         ValueRange devices,
                                         TypeRange deviceTypes,
                                         ValueRange queueAffinities) {
  p << "[";
  p.increaseIndent();
  p.printNewline();
  bool needsSeparator = false;
  for (auto [device, deviceType, queueAffinity] :
       llvm::zip_equal(devices, deviceTypes, queueAffinities)) {
    if (needsSeparator) {
      p << ",";
      p.printNewline();
    }
    needsSeparator = true;
    p << "(";
    p.printOperand(device);
    p << ", ";
    p.printOperand(queueAffinity);
    p << " : ";
    p.printType(deviceType);
    p << ", ";
    p.printType(queueAffinity.getType());
    p << ")";
  }
  p.decreaseIndent();
  p.printNewline();
  p << "]";
}

//===----------------------------------------------------------------------===//
// custom<DescriptorType>($descriptor_type)
//===----------------------------------------------------------------------===//

// Custom parser/printer that, unlike the autogenerated attribute
// parser/printer, omits the wrapping `<` and `>`.
// Parses a bare DescriptorType enum keyword into |dtAttr|.
// Unknown keywords produce a diagnostic at the keyword's location instead of
// failing silently.
static ParseResult parseDescriptorType(OpAsmParser &parser,
                                       DescriptorTypeAttr &dtAttr) {
  auto keywordLoc = parser.getCurrentLocation();
  StringRef enumKeyword;
  // parseKeyword emits its own diagnostic on failure.
  if (failed(parser.parseKeyword(&enumKeyword)))
    return failure();
  std::optional<DescriptorType> maybeEnum =
      symbolizeDescriptorType(enumKeyword);
  if (!maybeEnum) {
    return parser.emitError(keywordLoc)
           << "unknown descriptor type '" << enumKeyword << "'";
  }
  dtAttr = DescriptorTypeAttr::get(parser.getContext(), *maybeEnum);
  return success();
}

// Prints the DescriptorType enum as a bare keyword, with no `<` `>` wrapping.
static void printDescriptorType(OpAsmPrinter &p, Operation *,
                                DescriptorTypeAttr dtAttr) {
  auto descriptorType = dtAttr.getValue();
  p << stringifyDescriptorType(descriptorType);
}

//===----------------------------------------------------------------------===//
// custom<PipelineBindings>($binding_ordinals,
//                          $binding_buffers,
//                          type($binding_buffers),
//                          $binding_offsets,
//                          $binding_lengths)
//===----------------------------------------------------------------------===//

// Parses one or more comma-separated pipeline binding entries of the form
//   %ordinal = (%buffer : type)[%offset, %length]
// and appends each component to the matching output list.
static ParseResult parsePipelineBindings(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &ordinals,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &buffers,
    SmallVectorImpl<Type> &bufferTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferOffsets,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferLengths) {
  while (true) {
    OpAsmParser::UnresolvedOperand ordinalOperand;
    OpAsmParser::UnresolvedOperand bufferOperand;
    Type type;
    OpAsmParser::UnresolvedOperand offsetOperand;
    OpAsmParser::UnresolvedOperand lengthOperand;
    // ParseResult converts to `true` on failure, so the chain short-circuits at
    // the first parse error.
    bool parseFailed =
        parser.parseOperand(ordinalOperand) || parser.parseEqual() ||
        parser.parseLParen() || parser.parseOperand(bufferOperand) ||
        parser.parseColonType(type) || parser.parseRParen() ||
        parser.parseLSquare() || parser.parseOperand(offsetOperand) ||
        parser.parseComma() || parser.parseOperand(lengthOperand) ||
        parser.parseRSquare();
    if (parseFailed)
      return failure();
    ordinals.push_back(ordinalOperand);
    buffers.push_back(bufferOperand);
    bufferTypes.push_back(type);
    bufferOffsets.push_back(offsetOperand);
    bufferLengths.push_back(lengthOperand);
    if (failed(parser.parseOptionalComma()))
      break;
  }
  return success();
}

// Inverse of parsePipelineBindings. Each entry starts on a new line indented by
// two spaces, entries are separated by ", ", and a trailing newline follows the
// list.
static void printPipelineBindings(OpAsmPrinter &p, Operation *op,
                                  ValueRange ordinals, ValueRange buffers,
                                  TypeRange bufferTypes,
                                  ValueRange bufferOffsets,
                                  ValueRange bufferLengths) {
  bool isFirst = true;
  for (auto [ordinal, buffer, bufferType, bufferOffset, bufferLength] :
       llvm::zip_equal(ordinals, buffers, bufferTypes, bufferOffsets,
                       bufferLengths)) {
    if (!isFirst)
      p << ", ";
    isFirst = false;
    p.printNewline();
    p << "  ";
    p.printOperand(ordinal);
    p << " = (";
    p.printOperand(buffer);
    p << " : ";
    p.printType(bufferType);
    p << ")[";
    p.printOperand(bufferOffset);
    p << ", ";
    p.printOperand(bufferLength);
    p << "]";
  }
  p.printNewline();
}

//===----------------------------------------------------------------------===//
// custom<Bindings>($binding_buffers,
//                  type($binding_buffers),
//                  $binding_offsets,
//                  $binding_lengths)
//===----------------------------------------------------------------------===//

// Parses one or more comma-separated `(%buffer : type)[%offset, %length]`
// entries and appends each component to the matching output list.
static ParseResult
parseBindings(OpAsmParser &parser,
              SmallVectorImpl<OpAsmParser::UnresolvedOperand> &buffers,
              SmallVectorImpl<Type> &bufferTypes,
              SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferOffsets,
              SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferLengths) {
  for (;;) {
    OpAsmParser::UnresolvedOperand bufferOperand;
    OpAsmParser::UnresolvedOperand offsetOperand;
    OpAsmParser::UnresolvedOperand lengthOperand;
    Type type;
    if (parser.parseLParen() || parser.parseOperand(bufferOperand) ||
        parser.parseColonType(type) || parser.parseRParen() ||
        parser.parseLSquare() || parser.parseOperand(offsetOperand) ||
        parser.parseComma() || parser.parseOperand(lengthOperand) ||
        parser.parseRSquare()) {
      return failure();
    }
    buffers.push_back(bufferOperand);
    bufferTypes.push_back(type);
    bufferOffsets.push_back(offsetOperand);
    bufferLengths.push_back(lengthOperand);
    // Stop when no further comma-separated entry follows.
    if (failed(parser.parseOptionalComma()))
      return success();
  }
}

// Inverse of parseBindings. Each `(%buffer : type)[%offset, %length]` entry is
// printed on its own line with a two-space indent, separated by ", ", followed
// by a trailing newline.
static void printBindings(OpAsmPrinter &p, Operation *op, ValueRange buffers,
                          TypeRange bufferTypes, ValueRange bufferOffsets,
                          ValueRange bufferLengths) {
  bool needsComma = false;
  for (auto [buffer, bufferType, bufferOffset, bufferLength] :
       llvm::zip_equal(buffers, bufferTypes, bufferOffsets, bufferLengths)) {
    if (needsComma)
      p << ", ";
    needsComma = true;
    p.printNewline();
    p << "  (";
    p.printOperand(buffer);
    p << " : ";
    p.printType(bufferType);
    p << ")[";
    p.printOperand(bufferOffset);
    p << ", ";
    p.printOperand(bufferLength);
    p << "]";
  }
  p.printNewline();
}

//===----------------------------------------------------------------------===//
// custom<BindingTable>($binding_buffers,
//                      type($binding_buffers),
//                      $binding_offsets,
//                      $binding_lengths)
//===----------------------------------------------------------------------===//

// Parses one or more comma-separated binding table entries of the form
// `(%buffer : type)[%offset, %length]` into parallel output lists.
static ParseResult parseBindingTable(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &buffers,
    SmallVectorImpl<Type> &bufferTypes,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferOffsets,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &bufferLengths) {
  // Parses a single entry and, only on success, appends its components.
  auto parseEntry = [&]() -> ParseResult {
    OpAsmParser::UnresolvedOperand buffer;
    Type bufferType;
    OpAsmParser::UnresolvedOperand offset;
    OpAsmParser::UnresolvedOperand length;
    if (parser.parseLParen() || parser.parseOperand(buffer) ||
        parser.parseColonType(bufferType) || parser.parseRParen() ||
        parser.parseLSquare() || parser.parseOperand(offset) ||
        parser.parseComma() || parser.parseOperand(length) ||
        parser.parseRSquare())
      return failure();
    buffers.push_back(buffer);
    bufferTypes.push_back(bufferType);
    bufferOffsets.push_back(offset);
    bufferLengths.push_back(length);
    return success();
  };
  do {
    if (failed(parseEntry()))
      return failure();
  } while (succeeded(parser.parseOptionalComma()));
  return success();
}

// Inverse of parseBindingTable. Each entry goes on its own two-space-indented
// line, entries are separated by ", ", and a newline closes the list.
static void printBindingTable(OpAsmPrinter &p, Operation *op,
                              ValueRange buffers, TypeRange bufferTypes,
                              ValueRange bufferOffsets,
                              ValueRange bufferLengths) {
  llvm::interleave(
      llvm::zip_equal(buffers, bufferTypes, bufferOffsets, bufferLengths),
      [&](auto entry) {
        auto [buffer, bufferType, bufferOffset, bufferLength] = entry;
        p.printNewline();
        p << "  (";
        p.printOperand(buffer);
        p << " : ";
        p.printType(bufferType);
        p << ")[";
        p.printOperand(bufferOffset);
        p << ", ";
        p.printOperand(bufferLength);
        p << "]";
      },
      [&]() { p << ", "; });
  p.printNewline();
}

//===----------------------------------------------------------------------===//
// custom<TargetConditionRegion>($body)
//===----------------------------------------------------------------------===//

// Returns the signature of a target condition region: a single !hal.device
// input and a single i1 result.
static FunctionType getTargetConditionRegionType(MLIRContext *context) {
  Type deviceType = IREE::HAL::DeviceType::get(context);
  Type conditionType = IntegerType::get(context, 1);
  return FunctionType::get(context, TypeRange{deviceType},
                           TypeRange{conditionType});
}

// Verifies an optional target condition region. Empty regions are accepted.
// Otherwise the region must take exactly one !hal.device argument, and every
// hal.return directly inside the region's blocks must yield exactly one i1.
static LogicalResult verifyTargetConditionRegion(Operation *op,
                                                 Region &region) {
  if (region.empty())
    return success();

  bool takesSingleDevice =
      region.getNumArguments() == 1 &&
      isa<IREE::HAL::DeviceType>(region.getArgumentTypes().front());
  if (!takesSingleDevice) {
    return op->emitOpError()
           << "target condition region must take a !hal.device";
  }

  for (auto returnOp : region.getOps<IREE::HAL::ReturnOp>()) {
    bool returnsSingleI1 = returnOp.getNumOperands() == 1 &&
                           returnOp->getOperand(0).getType().isInteger(1);
    if (!returnsSingleI1) {
      return returnOp.emitOpError()
             << "target condition region must return a single i1 result";
    }
  }
  return success();
}

// Parses `(%arg: type, ...) -> i1 { ... }` into |body|. Exactly one i1 return
// type is required.
static ParseResult parseTargetConditionRegion(OpAsmParser &parser,
                                              Region &body) {
  SmallVector<OpAsmParser::Argument> args;
  if (failed(parser.parseArgumentList(args, AsmParser::Delimiter::Paren,
                                      /*allowType=*/true,
                                      /*allowAttrs=*/true))) {
    return failure();
  }

  // Capture the location before the arrow type list so the diagnostic points
  // at the offending types rather than at the region that follows them.
  auto returnTypesLoc = parser.getCurrentLocation();
  SmallVector<Type> returnTypes;
  if (failed(parser.parseArrowTypeList(returnTypes))) {
    return failure();
  }
  if (returnTypes.size() != 1 ||
      !llvm::all_of(returnTypes, [](Type type) { return type.isInteger(1); })) {
    return parser.emitError(returnTypesLoc)
           << "target condition region must return one i1";
  }

  return parser.parseRegion(body, args, /*enableNameShadowing=*/false);
}

// Inverse of parseTargetConditionRegion. Prints the region arguments, an arrow
// list containing a single i1, and the region body including terminators.
// Nothing is printed for an empty region.
static void printTargetConditionRegion(OpAsmPrinter &p, Operation *op,
                                       Region &body) {
  if (body.empty())
    return;
  Type i1Type = IntegerType::get(body.getContext(), 1);
  p << "(";
  for (auto [index, arg] : llvm::enumerate(body.getArguments())) {
    if (index > 0)
      p << ", ";
    p.printRegionArgument(arg);
  }
  p << ")";
  p.printArrowTypeList(TypeRange{i1Type});
  p << " ";
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}

// Parses a comma-separated list of per-target entries:
//   #target [if(...) -> i1 { ... }] ordinal(N) = [#object, ...]
// Each entry adds one element to each of the three array attributes and one
// region to |targetRegions|. The region is left empty when no `if` condition
// is present.
static ParseResult parseTargetConditionObjects(
    OpAsmParser &parser, ArrayAttr &targetsAttr, ArrayAttr &targetOrdinalsAttr,
    ArrayAttr &targetObjectsAttr,
    SmallVector<std::unique_ptr<Region>, 2> &targetRegions) {
  MLIRContext *context = parser.getContext();
  Type indexType = IndexType::get(context);
  SmallVector<Attribute> targets;
  SmallVector<Attribute> ordinals;
  SmallVector<Attribute> objectLists;
  while (true) {
    // #hal.executable.target<...>
    Attribute target;
    if (parser.parseAttribute(target))
      return failure();
    targets.push_back(target);

    // Optional `if(...) -> i1 { ... }` condition region.
    auto conditionRegion = std::make_unique<Region>();
    bool hasCondition = succeeded(parser.parseOptionalKeyword("if"));
    if (hasCondition &&
        failed(parseTargetConditionRegion(parser, *conditionRegion)))
      return failure();
    targetRegions.push_back(std::move(conditionRegion));

    // ordinal(#)
    Attribute ordinal;
    if (parser.parseKeyword("ordinal") || parser.parseLParen() ||
        parser.parseAttribute(ordinal, indexType) || parser.parseRParen())
      return failure();
    ordinals.push_back(ordinal);

    // = [#hal.executable.object<...>, ...]
    ArrayAttr objectList;
    if (parser.parseEqual() || parser.parseAttribute(objectList))
      return failure();
    objectLists.push_back(objectList);

    if (failed(parser.parseOptionalComma()))
      break;
  }
  targetsAttr = ArrayAttr::get(context, targets);
  targetOrdinalsAttr = ArrayAttr::get(context, ordinals);
  targetObjectsAttr = ArrayAttr::get(context, objectLists);
  return success();
}

// Inverse of parseTargetConditionObjects. Prints one target entry per indented
// line, separated by commas. The `if` condition is printed only when the
// target's region is non-empty.
static void printTargetConditionObjects(OpAsmPrinter &p, Operation *op,
                                        ArrayAttr targetsAttr,
                                        ArrayAttr targetOrdinalsAttr,
                                        ArrayAttr targetObjectsAttr,
                                        MutableArrayRef<Region> targetRegions) {
  p.increaseIndent();
  p.printNewline();

  bool isFirst = true;
  for (auto [target, ordinal, objects, region] :
       llvm::zip_equal(targetsAttr, targetOrdinalsAttr, targetObjectsAttr,
                       targetRegions)) {
    if (!isFirst) {
      p << ",";
      p.printNewline();
    }
    isFirst = false;
    p.printAttribute(target);
    if (!region.empty()) {
      p << " if";
      printTargetConditionRegion(p, op, region);
    }
    p << " ordinal(";
    p.printAttributeWithoutType(ordinal);
    p << ") = ";
    p.printAttribute(objects);
  }

  p.decreaseIndent();
  p.printNewline();
}

//===----------------------------------------------------------------------===//
// custom<WorkgroupCountRegion>($body)
//===----------------------------------------------------------------------===//

// Parses `(%arg: type, ...) -> (index, index, index) { ... }` into |body|.
// The three index results are the workgroup counts in X, Y, and Z.
static ParseResult parseWorkgroupCountRegion(OpAsmParser &parser,
                                             Region &body) {
  SmallVector<OpAsmParser::Argument> args;
  if (failed(parser.parseArgumentList(args, AsmParser::Delimiter::Paren,
                                      /*allowType=*/true,
                                      /*allowAttrs=*/true))) {
    return failure();
  }

  // Return types must be 3 dimensions (workgroup count XYZ). Capture the
  // location before parsing them so the diagnostic points at the types rather
  // than at the region that follows.
  auto returnTypesLoc = parser.getCurrentLocation();
  SmallVector<Type> returnTypes;
  if (failed(parser.parseArrowTypeList(returnTypes))) {
    return failure();
  }
  if (returnTypes.size() != 3 ||
      !llvm::all_of(returnTypes, [](Type type) { return type.isIndex(); })) {
    return parser.emitError(returnTypesLoc)
           << "workgroup count region must return the XYZ dimension counts";
  }

  // Parse region contents.
  return parser.parseRegion(body, args, /*enableNameShadowing=*/false);
}

// Inverse of parseWorkgroupCountRegion. Prints the region arguments, an arrow
// list of three index types, and the region body including terminators.
// Nothing is printed for an empty region.
static void printWorkgroupCountRegion(OpAsmPrinter &p, Operation *op,
                                      Region &body) {
  if (body.empty())
    return;
  p << "(";
  bool needsComma = false;
  for (BlockArgument arg : body.getArguments()) {
    if (needsComma)
      p << ", ";
    needsComma = true;
    p.printRegionArgument(arg);
  }
  p << ")";
  Type idx = IndexType::get(body.getContext());
  p.printArrowTypeList(TypeRange{idx, idx, idx});
  p << " ";
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}

//===----------------------------------------------------------------------===//
// custom<ExportConditionRegion>($body)
//===----------------------------------------------------------------------===//

// Parses `(%arg: type, ...) -> i1 { ... }` into |body|. Exactly one i1 return
// type is required.
static ParseResult parseExportConditionRegion(OpAsmParser &parser,
                                              Region &body) {
  SmallVector<OpAsmParser::Argument> args;
  if (failed(parser.parseArgumentList(args, AsmParser::Delimiter::Paren,
                                      /*allowType=*/true,
                                      /*allowAttrs=*/true))) {
    return failure();
  }

  // Return types must be an i1. Capture the location before parsing them so
  // the diagnostic points at the types rather than at the region that follows.
  auto returnTypesLoc = parser.getCurrentLocation();
  SmallVector<Type> returnTypes;
  if (failed(parser.parseArrowTypeList(returnTypes))) {
    return failure();
  }
  if (returnTypes.size() != 1 ||
      !llvm::all_of(returnTypes, [](Type type) { return type.isInteger(1); })) {
    return parser.emitError(returnTypesLoc)
           << "condition region must return a boolean value";
  }

  // Parse region contents.
  return parser.parseRegion(body, args,
                            /*enableNameShadowing=*/false);
}

// Inverse of parseExportConditionRegion. Prints the region arguments, an arrow
// list containing a single i1, and the region body including terminators.
// Nothing is printed for an empty region.
static void printExportConditionRegion(OpAsmPrinter &p, Operation *op,
                                       Region &body) {
  if (body.empty())
    return;
  Type boolType = IntegerType::get(op->getContext(), 1);
  p << "(";
  llvm::interleave(
      body.getArguments(),
      [&](BlockArgument arg) { p.printRegionArgument(arg); },
      [&]() { p << ", "; });
  p << ")";
  p.printArrowTypeList(TypeRange{boolType});
  p << " ";
  p.printRegion(body, /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}

//===----------------------------------------------------------------------===//
// hal.ex.*
//===----------------------------------------------------------------------===//

// Suggests `%memory_file` as the SSA name for the op's result.
void ExFileFromMemoryOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Value fileResult = getResult();
  setNameFn(fileResult, "memory_file");
}

//===----------------------------------------------------------------------===//
// hal.return
//===----------------------------------------------------------------------===//

// When the immediate parent implements FunctionOpInterface, the returned
// operands must match the parent's result types in count and in order. Returns
// under any other parent are not checked here.
LogicalResult ReturnOp::verify() {
  Operation *returnOp = getOperation();
  auto parentFuncOp =
      dyn_cast_or_null<mlir::FunctionOpInterface>(returnOp->getParentOp());
  if (!parentFuncOp)
    return success();

  auto expectedTypes = parentFuncOp.getResultTypes();
  unsigned operandCount = returnOp->getNumOperands();
  if (operandCount != expectedTypes.size()) {
    return emitOpError() << "return must have the same number of operands "
                            "as the parent result signature (have "
                         << operandCount << ", expected "
                         << expectedTypes.size() << ")";
  }
  for (unsigned i = 0; i < operandCount; ++i) {
    Type actualType = returnOp->getOperand(i).getType();
    Type expectedType = expectedTypes[i];
    if (actualType != expectedType) {
      return emitOpError() << "parent expected result " << i << " to be "
                           << expectedType << " but returning " << actualType;
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// hal.tensor.import/export
//===----------------------------------------------------------------------===//

// Convenience overload that forwards to the fence-taking builder with a null
// wait fence.
void TensorImportOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value source,
                           TypeAttr targetEncoding, bool consume,
                           StringAttr name, Attribute affinity) {
  Value noWaitFence;
  build(builder, result, resultType, source, targetEncoding, consume,
        noWaitFence, name, affinity);
}

// Builds an import, deriving the target's dynamic dimension values by emitting
// one hal.buffer_view.dim query on |source| per dynamic dimension of
// |resultType|. The assertion requires the source to be a buffer view whenever
// |resultType| is not fully static.
void TensorImportOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value source,
                           TypeAttr targetEncoding, bool consume,
                           Value waitFence, StringAttr name,
                           Attribute affinity) {
  auto shapedType = llvm::cast<ShapedType>(resultType);
  assert((isa<IREE::HAL::BufferViewType>(source.getType()) ||
          shapedType.hasStaticShape()) &&
         "can only use this constructor for buffer views when shape "
         "information is required");
  SmallVector<Value> dynamicDims;
  for (auto [dimIndex, dimSize] : llvm::enumerate(shapedType.getShape())) {
    if (ShapedType::isDynamic(dimSize)) {
      dynamicDims.push_back(builder.createOrFold<IREE::HAL::BufferViewDimOp>(
          result.location, builder.getIndexType(), source,
          builder.getIndexAttr(dimIndex)));
    }
  }
  UnitAttr consumeAttr;
  if (consume)
    consumeAttr = builder.getUnitAttr();
  build(builder, result, resultType, source, targetEncoding, dynamicDims,
        consumeAttr, waitFence, name, affinity);
}

// Checks that a user-specified encoding type is compatible with the storage
// type it describes, emitting an op error on |op| when it is not.
//
// Returns success when the two types are identical or when either one is not
// a ShapedType. The element byte-width comparison below is currently disabled:
// its branch is intentionally empty (see the TODO). The only enforced rule is
// that both shaped types have the same number of dynamic dimensions.
static LogicalResult verifyTypeStorageCompatibility(Operation *op,
                                                    Type encodingType,
                                                    Type storageType) {
  if (encodingType == storageType)
    return success();
  auto encodingShapedType = llvm::dyn_cast<ShapedType>(encodingType);
  auto storageShapedType = llvm::dyn_cast<ShapedType>(storageType);
  // Non-shaped types are not checked.
  if (!encodingShapedType || !storageShapedType)
    return success();

  // NOTE: the result of this comparison is currently unused; the check is
  // disabled and the branch body only holds the commented-out diagnostic.
  if (IREE::Util::getRoundedElementByteWidth(
          encodingShapedType.getElementType()) !=
      IREE::Util::getRoundedElementByteWidth(
          storageShapedType.getElementType())) {
    // TODO(benvanik): more sophisticated logic here. There are a lot of valid
    // cases that are difficult to account for here statically; for example,
    // packing 8xi1 into 1xi8 or complex<f32> into 2xf32. We could try to guess
    // the element count (at least the static part of it) and ensure the scaling
    // matches but that wouldn't account for user variance. Really with this op
    // we are letting the _user_ control the bitcasting and type reflection and
    // purposefully don't want to mess with it (users should be able to put
    // custom types here, etc).
    //
    // NOTE: we round to bytes first as the base type (such as i1) may not be
    // representable in an external form.
    // return op->emitOpError() << "encoding and storage types must be "
    //                             "bitcastable; adjusted encoding bit width "
    //                             "of "
    //                          << encodingShapedType.getElementTypeBitWidth()
    //                          << " != adjusted storage bit width of "
    //                          << storageShapedType.getElementTypeBitWidth();
  }

  if (encodingShapedType.getNumDynamicDims() !=
      storageShapedType.getNumDynamicDims()) {
    // NOTE: we implicitly require that the dimensions are equivalent but
    // don't actually care about their order. For example, tensor<?x1xf32> is
    // compatible with tensor<?xf32>.
    return op->emitOpError()
           << "encoding and storage types must have the same "
              "dynamic dimension values; encoding shape "
           << encodingShapedType << " incompatible with storage shape "
           << storageShapedType;
  }

  return success();
}

// The number of target_dims must equal the number of dynamic dimensions in the
// target tensor type, and the target encoding must be storage-compatible with
// that type.
LogicalResult TensorImportOp::verify() {
  auto targetType = llvm::cast<TensorType>(getTarget().getType());
  size_t expectedDimCount = targetType.getNumDynamicDims();
  if (getTargetDims().size() != expectedDimCount) {
    return emitOpError() << "number of target_dims must match number of "
                            "dynamic dims in target type";
  }
  return verifyTypeStorageCompatibility(getOperation(), getTargetEncoding(),
                                        targetType);
}

// Builds an export, materializing SSA values for every dynamic dimension of
// |source| and forwarding them as the source dims.
void TensorExportOp::build(OpBuilder &builder, OperationState &result,
                           Type resultType, Value source,
                           TypeAttr sourceEncoding, StringAttr name,
                           Attribute affinity) {
  Location loc = result.location;
  auto sourceDims = IREE::Util::buildDynamicDimsForValue(loc, source, builder);
  build(builder, result, resultType, source, sourceEncoding, sourceDims, name,
        affinity);
}

// The number of source_dims must equal the number of dynamic dimensions in the
// source tensor type, and the source encoding must be storage-compatible with
// that type.
LogicalResult TensorExportOp::verify() {
  Type sourceStorageType = getSource().getType();
  auto sourceType = llvm::cast<TensorType>(sourceStorageType);
  size_t expectedDimCount = sourceType.getNumDynamicDims();
  if (getSourceDims().size() != expectedDimCount) {
    return emitOpError() << "number of source_dims must match number of "
                            "dynamic dims in source type";
  }
  return verifyTypeStorageCompatibility(getOperation(), getSourceEncoding(),
                                        sourceStorageType);
}

//===----------------------------------------------------------------------===//
// hal.tensor.alias
//===----------------------------------------------------------------------===//

// The result is always tied to the source operand, whatever |resultIndex| is.
// Returns the tied base value of the source as computed by findTiedBaseValue.
Value TensorAliasOp::getTiedResult(unsigned resultIndex) {
  Value source = getSource();
  return IREE::Util::TiedOpInterface::findTiedBaseValue(source);
}

::std::optional<unsigned>
TensorAliasOp::getTiedResultOperandIndex(unsigned resultIndex) {
  // The result is always tied to operand 0 (source).
  ::std::optional<unsigned> sourceIndex = 0u;
  return sourceIndex;
}

SmallVector<int64_t> TensorAliasOp::getTiedResultOperandIndices() {
  // A single tie: result 0 -> operand 0 (source).
  SmallVector<int64_t> indices;
  indices.push_back(0);
  return indices;
}

// The number of source_dims must equal the number of dynamic dimensions in the
// source tensor type.
LogicalResult TensorAliasOp::verify() {
  auto sourceType = llvm::cast<TensorType>(getSource().getType());
  bool dimsMatch = static_cast<size_t>(sourceType.getNumDynamicDims()) ==
                   getSourceDims().size();
  if (dimsMatch)
    return success();
  return emitOpError() << "number of dynamic dims must match the operand type";
}

//===----------------------------------------------------------------------===//
// hal.tensor.barrier
//===----------------------------------------------------------------------===//

// Builds a barrier whose result types mirror the types of |sources|, one
// result per source.
void TensorBarrierOp::build(OpBuilder &builder, OperationState &result,
                            ValueRange sources, Value signalFence) {
  auto resultTypes = llvm::to_vector(sources.getTypes());
  build(builder, result, resultTypes, sources, signalFence);
}

// Result |resultIndex| is tied to sources[resultIndex]. Returns that source's
// tied base value as computed by findTiedBaseValue.
Value TensorBarrierOp::getTiedResult(unsigned resultIndex) {
  Value tiedSource = getSources()[resultIndex];
  return IREE::Util::TiedOpInterface::findTiedBaseValue(tiedSource);
}

::std::optional<unsigned>
TensorBarrierOp::getTiedResultOperandIndex(unsigned resultIndex) {
  // Positional tie: result i corresponds to sources[i].
  ::std::optional<unsigned> tiedIndex = resultIndex;
  return tiedIndex;
}

SmallVector<int64_t> TensorBarrierOp::getTiedResultOperandIndices() {
  // Every result is tied to the source at the same position: [0, N).
  int64_t sourceCount = getSources().size();
  SmallVector<int64_t> indices;
  indices.reserve(sourceCount);
  for (int64_t i = 0; i < sourceCount; ++i)
    indices.push_back(i);
  return indices;
}

//===----------------------------------------------------------------------===//
// hal.dispatch.extern
//===----------------------------------------------------------------------===//

// Builds a hal.dispatch.extern op. Operands are added in segment order:
// workload, arguments, argument dims, result dims. Caller-provided tied-operand
// and segment-size attributes are replaced by ones derived from the builder
// arguments. Regions are a workgroup count region (left empty for callers to
// populate) followed by one empty region per target.
void DispatchExternOp::build(OpBuilder &builder, OperationState &state,
                             ValueRange workload, TypeRange resultTypes,
                             ValueRange resultDims, ValueRange arguments,
                             ValueRange argumentDims,
                             ArrayRef<int64_t> tiedOperands,
                             IREE::HAL::ExecutableObjectsAttr targetObjects,
                             ArrayRef<NamedAttribute> attributes) {
  state.addTypes(resultTypes);
  for (ValueRange segment : {workload, arguments, argumentDims, resultDims})
    state.addOperands(segment);

  state.addAttributes(attributes);

  auto tiedAttrName = IREE::Util::TiedOpInterface::getStorageAttrName();
  state.attributes.erase(tiedAttrName);
  state.addAttribute(tiedAttrName, builder.getIndexArrayAttr(tiedOperands));

  state.addAttribute("targets", targetObjects.getTargets());
  state.addAttribute("target_objects", targetObjects.getTargetObjects());

  // Must list the segments in the same order the operands were added above.
  int32_t segmentSizes[] = {
      static_cast<int32_t>(workload.size()),
      static_cast<int32_t>(arguments.size()),
      static_cast<int32_t>(argumentDims.size()),
      static_cast<int32_t>(resultDims.size()),
  };
  auto segmentAttrName = getOperandSegmentSizeAttr();
  state.attributes.erase(segmentAttrName);
  state.addAttribute(segmentAttrName,
                     builder.getDenseI32ArrayAttr(segmentSizes));

  // One workgroup count region plus one region per target, all empty.
  size_t regionCount = 1 + targetObjects.getTargets().size();
  for (size_t i = 0; i < regionCount; ++i)
    state.addRegion();
}

// Verifies that |dynamicDims| holds exactly one value per dynamic dimension
// across all shaped types in |values|. Non-shaped values contribute nothing.
static LogicalResult verifyOpDynamicDims(Operation *op, ValueRange values,
                                         ValueRange dynamicDims) {
  unsigned requiredCount = 0;
  for (Type valueType : values.getTypes()) {
    auto shapedType = llvm::dyn_cast<ShapedType>(valueType);
    if (!shapedType)
      continue;
    requiredCount += shapedType.getNumDynamicDims();
  }
  if (dynamicDims.size() == requiredCount)
    return success();
  return op->emitOpError()
         << "value set has " << requiredCount
         << " dynamic dimensions but only " << dynamicDims.size()
         << " dimension values are attached";
}

// Verifies that |workload| matches, in count and in per-position type, the
// block arguments of |region| that are not !hal.device.
static LogicalResult verifyWorkgroupCountWorkload(Operation *op,
                                                  ValueRange workload,
                                                  Region &region) {
  SmallVector<Type> capturedTypes;
  for (Type argType : region.getArgumentTypes()) {
    if (!isa<IREE::HAL::DeviceType>(argType))
      capturedTypes.push_back(argType);
  }
  if (workload.size() != capturedTypes.size()) {
    return op->emitOpError()
           << "workload operands and workgroup count args mismatch ("
           << workload.size() << " vs " << capturedTypes.size() << ")";
  }
  for (size_t index = 0; index < capturedTypes.size(); ++index) {
    Type operandType = workload[index].getType();
    Type capturedType = capturedTypes[index];
    if (operandType == capturedType)
      continue;
    return op->emitOpError()
           << "workload value " << index << " type mismatch; operand is "
           << operandType << " but region captures " << capturedType;
  }
  return success();
}

// Verifies that the workgroup count region matches the expected signature:
// a leading !hal.device argument followed only by index arguments, with every
// hal.return in the region yielding exactly three index values.
// Returns success if the region is empty.
static LogicalResult verifyWorkgroupCountRegion(Operation *op, Region &region) {
  if (region.empty())
    return success();

  // Supported signature: (!hal.device, index, index, ...).
  auto arguments = region.getArguments();
  bool validArguments =
      !arguments.empty() &&
      llvm::isa<IREE::HAL::DeviceType>(arguments.front().getType()) &&
      llvm::all_of(arguments.drop_front(1), [](BlockArgument blockArg) {
        // All arguments after the device need to be of type index (today).
        return llvm::isa<IndexType>(blockArg.getType());
      });
  if (!validArguments) {
    return op->emitOpError(
        "expected workgroup_count to take (%device: !hal.device, "
        "%workload_0: index, %workload_1: index, ...)");
  }

  // Verify the return types are XYZ index counts. emitOpError is used here as
  // well so every diagnostic from this verifier carries the op-name prefix.
  for (auto returnOp : region.getOps<IREE::HAL::ReturnOp>()) {
    auto returnTypes = returnOp.getOperandTypes();
    if (returnTypes.size() != 3 ||
        !llvm::all_of(returnTypes, [](Type type) { return type.isIndex(); })) {
      return op->emitOpError(
          "workgroup count region must return the XYZ dimension counts as "
          "`index` types");
    }
  }

  return success();
}

LogicalResult DispatchExternOp::verify() {
  Operation *op = getOperation();

  // Each dynamic dimension of the arguments and results needs a dim operand.
  if (failed(verifyOpDynamicDims(op, getArguments(), getArgumentDims())))
    return failure();
  if (failed(verifyOpDynamicDims(op, getResults(), getResultDims())))
    return failure();

  // Shaped operands/results with an index element type are rejected.
  auto verifyIOType = [&](Type type) -> LogicalResult {
    auto shapedType = llvm::dyn_cast<ShapedType>(type);
    if (!shapedType || !shapedType.getElementType().isIndex())
      return success();
    return op->emitOpError() << "I/O type " << type
                             << " is invalid: index types must not cross "
                                "the dispatch boundary";
  };
  for (Type operandType : getOperandTypes()) {
    if (failed(verifyIOType(operandType)))
      return failure();
  }
  for (Type resultType : getResultTypes()) {
    if (failed(verifyIOType(resultType)))
      return failure();
  }

  // The workgroup count region (empty is allowed) must have a valid signature
  // and agree with the workload operands.
  Region &workgroupCount = getWorkgroupCount();
  if (failed(verifyWorkgroupCountRegion(op, workgroupCount)) ||
      failed(verifyWorkgroupCountWorkload(op, getWorkload(), workgroupCount))) {
    return failure();
  }

  // targets, target_objects, and the per-target regions are parallel arrays.
  size_t targetCount = getTargets().size();
  if (targetCount != getTargetObjects().size()) {
    return op->emitOpError() << "target and objects arrays must match";
  }
  if (targetCount != getTargetRegions().size()) {
    return op->emitOpError()
           << "target and condition regions must match (but they may be empty)";
  }
  for (Region &targetRegion : getTargetRegions()) {
    if (failed(verifyTargetConditionRegion(op, targetRegion)))
      return failure();
  }

  return success();
}

// Tied operand indices are relative to ODS operand segment 1, which the
// builder populates with the dispatch arguments.
std::pair<unsigned, unsigned>
DispatchExternOp::getTiedOperandsIndexAndLength() {
  constexpr unsigned kArgumentsSegment = 1;
  return getODSOperandIndexAndLength(kArgumentsSegment);
}

//===----------------------------------------------------------------------===//
// hal.device.memoize
//===----------------------------------------------------------------------===//

// Builds a hal.device.memoize op with operands (device, queueAffinity) and a
// single empty body region for the caller to populate.
void DeviceMemoizeOp::build(OpBuilder &builder, OperationState &state,
                            TypeRange resultTypes, Value device,
                            Value queueAffinity,
                            ArrayRef<NamedAttribute> attributes) {
  state.addTypes(resultTypes);
  state.addOperands({device, queueAffinity});
  state.addAttributes(attributes);
  state.addRegion();
}

void DeviceMemoizeOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  // Control flow is unconditional: from the parent op we always enter the
  // body region, and from the body we always return to the parent.
  if (point.isParent()) {
    regions.push_back(RegionSuccessor(&getBody(), getBody().getArguments()));
    return;
  }
  regions.push_back(RegionSuccessor({}));
}

//===----------------------------------------------------------------------===//
// hal.allocator.select
//===----------------------------------------------------------------------===//

void AllocatorSelectOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest names for the selected (device, queue affinity) results.
  Value selectedDevice = getSelectedDevice();
  Value selectedAffinity = getSelectedQueueAffinity();
  setNameFn(selectedDevice, "device");
  setNameFn(selectedAffinity, "queue_affinity");
}

//===----------------------------------------------------------------------===//
// hal.allocator.allocate
//===----------------------------------------------------------------------===//

void AllocatorAllocateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The allocated result prints as %buffer.
  Value allocated = getResult();
  setNameFn(allocated, "buffer");
}

// Always reports no size (a null Value), whatever |idx| is.
Value AllocatorAllocateOp::getOperandSize(unsigned idx) { return Value(); }

// Reports the value of the ODS `getResultSize()` accessor for any |idx|.
Value AllocatorAllocateOp::getResultSize(unsigned idx) {
  Value allocationSize = getResultSize();
  return allocationSize;
}

//===----------------------------------------------------------------------===//
// hal.allocator.import
//===----------------------------------------------------------------------===//

void AllocatorImportOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Name the success flag and the imported buffer.
  Value didImport = getDidImport();
  Value mapped = getResult();
  setNameFn(didImport, "did_import");
  setNameFn(mapped, "mapped");
}

// Always reports no size (a null Value), whatever |idx| is.
Value AllocatorImportOp::getOperandSize(unsigned idx) { return Value(); }

// The imported result's size is the op's length operand for any |idx|.
Value AllocatorImportOp::getResultSize(unsigned idx) {
  Value importedLength = getLength();
  return importedLength;
}

//===----------------------------------------------------------------------===//
// hal.allocator.resolve_memory_properties
//===----------------------------------------------------------------------===//

void AllocatorResolveMemoryPropertiesOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Result 0 is named memory_types and result 1 buffer_usage.
  Value memoryTypes = getResult(0);
  Value bufferUsage = getResult(1);
  setNameFn(memoryTypes, "memory_types");
  setNameFn(bufferUsage, "buffer_usage");
}

//===----------------------------------------------------------------------===//
// hal.buffer.allocation.discard
//===----------------------------------------------------------------------===//

void BufferAllocationDiscardOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The result prints as %was_terminal.
  Value wasTerminal = getResult();
  setNameFn(wasTerminal, "was_terminal");
}

//===----------------------------------------------------------------------===//
// hal.buffer.allocation.is_terminal
//===----------------------------------------------------------------------===//

void BufferAllocationIsTerminalOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The result prints as %is_terminal.
  Value isTerminal = getResult();
  setNameFn(isTerminal, "is_terminal");
}

//===----------------------------------------------------------------------===//
// hal.buffer.subspan
//===----------------------------------------------------------------------===//

void BufferSubspanOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The subspan result prints as %buffer.
  Value subspan = getResult();
  setNameFn(subspan, "buffer");
}

// Every sized operand is reported with the subspan length, whatever |idx| is.
Value BufferSubspanOp::getOperandSize(unsigned idx) {
  Value subspanLength = getLength();
  return subspanLength;
}

// The subspan result's size is the subspan length, whatever |idx| is.
Value BufferSubspanOp::getResultSize(unsigned idx) {
  Value subspanLength = getLength();
  return subspanLength;
}

//===----------------------------------------------------------------------===//
// hal.buffer.byte_length
//===----------------------------------------------------------------------===//

void BufferLengthOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The queried byte length prints as %length.
  Value byteLength = getResult();
  setNameFn(byteLength, "length");
}

//===----------------------------------------------------------------------===//
// hal.buffer_usage
//===----------------------------------------------------------------------===//

void BufferUsageOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The usage bitfield prints as %buffer_usage.
  Value usage = getResult();
  setNameFn(usage, "buffer_usage");
}

//===----------------------------------------------------------------------===//
// hal.memory_type
//===----------------------------------------------------------------------===//

void MemoryTypeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The memory type bitfield prints as %memory_type.
  Value memoryType = getResult();
  setNameFn(memoryType, "memory_type");
}

//===----------------------------------------------------------------------===//
// hal.element_type
//===----------------------------------------------------------------------===//

// Keep these in sync with iree/hal/buffer_view.h
//
// Numerical category of an element type. The 0x10 / 0x20 bits select the
// integer or floating-point family and the low nibble refines it. The value is
// packed into the top byte of an element type value by makeElementTypeValue.
enum class NumericalType : uint32_t {
  kUnknown = 0x00,
  kInteger = 0x10, // Signless integers map here (see getTypeValue).
  kIntegerSigned = kInteger | 0x01,
  kIntegerUnsigned = kInteger | 0x02,
  kBoolean = kInteger | 0x03, // i1 is encoded with this at 8 bits.
  kFloat = 0x20,
  kFloatIEEE = kFloat | 0x01,
  kFloatBrain = kFloat | 0x02,
  kFloatComplex = kFloat | 0x03,
  kFloat8E5M2 = kFloat | 0x04,
  kFloat8E4M3FN = kFloat | 0x05,
  kFloat8E5M2FNUZ = kFloat | 0x06,
  kFloat8E4M3FNUZ = kFloat | 0x07,
  kFloat8E8M0FNU = kFloat | 0x08,
};

// Packs |numericalType| into bits [24, 32) and |bitCount| into the low bits of
// an element type value (the layout iree/hal/buffer_view.h expects).
constexpr inline int32_t makeElementTypeValue(NumericalType numericalType,
                                              int32_t bitCount) {
  uint32_t typeBits = static_cast<uint32_t>(numericalType) << 24;
  return typeBits | static_cast<uint32_t>(bitCount);
}

// Returns the packed HAL element type value for |type|, or std::nullopt when
// the type has no HAL encoding. Any 1-bit integer is encoded as an 8-bit
// boolean. Complex types report twice their component bit width.
// static
std::optional<int32_t> ElementTypeOp::getTypeValue(Type type) {
  if (auto intType = llvm::dyn_cast_if_present<IntegerType>(type)) {
    if (intType.isInteger(1))
      return makeElementTypeValue(NumericalType::kBoolean, 8);
    // There's no such thing as a signless integer in machine types but we
    // need to be able to round-trip the format through the ABI. Exact
    // numerical type equality comparisons may fail if the frontend assumes
    // signed/unsigned but the compiler is propagating signless.
    NumericalType numericalType = NumericalType::kInteger;
    if (intType.isSigned())
      numericalType = NumericalType::kIntegerSigned;
    else if (intType.isUnsigned())
      numericalType = NumericalType::kIntegerUnsigned;
    return makeElementTypeValue(numericalType, intType.getWidth());
  }

  if (auto floatType = llvm::dyn_cast_if_present<FloatType>(type)) {
    unsigned bitWidth = floatType.getWidth();
    switch (APFloat::SemanticsToEnum(floatType.getFloatSemantics())) {
    case APFloat::S_Float8E5M2:
      return makeElementTypeValue(NumericalType::kFloat8E5M2, 8);
    case APFloat::S_Float8E4M3FN:
      return makeElementTypeValue(NumericalType::kFloat8E4M3FN, 8);
    case APFloat::S_Float8E5M2FNUZ:
      return makeElementTypeValue(NumericalType::kFloat8E5M2FNUZ, 8);
    case APFloat::S_Float8E4M3FNUZ:
      return makeElementTypeValue(NumericalType::kFloat8E4M3FNUZ, 8);
    case APFloat::S_Float8E8M0FNU:
      return makeElementTypeValue(NumericalType::kFloat8E8M0FNU, 8);
    case APFloat::S_IEEEhalf:
    case APFloat::S_IEEEsingle:
    case APFloat::S_IEEEdouble:
    case APFloat::S_IEEEquad:
      return makeElementTypeValue(NumericalType::kFloatIEEE, bitWidth);
    case APFloat::S_BFloat:
      return makeElementTypeValue(NumericalType::kFloatBrain, bitWidth);
    default:
      // Other float semantics have no HAL element type encoding.
      return std::nullopt;
    }
  }

  if (auto complexType = llvm::dyn_cast_if_present<ComplexType>(type)) {
    unsigned componentBitWidth =
        complexType.getElementType().getIntOrFloatBitWidth();
    return makeElementTypeValue(NumericalType::kFloatComplex,
                                componentBitWidth * 2);
  }

  return std::nullopt;
}

void ElementTypeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Name the result "element_type_" followed by the printed type attribute.
  // We could make this match the C names.
  std::string resultName = "element_type_";
  llvm::raw_string_ostream nameStream(resultName);
  nameStream << getTypeAttr();
  setNameFn(getResult(), resultName);
}

LogicalResult ElementTypeOp::verify() {
  // Only element types that getTypeValue can encode are accepted.
  std::optional<int32_t> encoded = getTypeValue(getTypeAttr().getValue());
  if (encoded.has_value())
    return success();
  return emitOpError("unsupported element type");
}

//===----------------------------------------------------------------------===//
// hal.encoding_type
//===----------------------------------------------------------------------===//

// Maps an encoding attribute to its HAL encoding type value. Only the absent
// (default) encoding is supported and yields 1; any attribute yields nullopt.
// static
std::optional<int32_t> EncodingTypeOp::getTypeValue(Attribute attr) {
  // TODO(#6762): encoding attribute handling/mapping to enums.
  if (!attr) {
    // Default to IREE_HAL_ENCODING_TYPE_DENSE_ROW_MAJOR for now.
    return 1;
  }
  return std::nullopt;
}

void EncodingTypeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Only the default (attribute-less) encoding gets a suggested name.
  if (getEncodingAttr())
    return;
  setNameFn(getResult(), "dense_row_major");
}

LogicalResult EncodingTypeOp::verify() {
  // Only encodings that getTypeValue can map are accepted.
  std::optional<int32_t> encoded = getTypeValue(getEncodingAttr());
  if (encoded.has_value())
    return success();
  return emitOpError("unsupported encoding type");
}

//===----------------------------------------------------------------------===//
// hal.buffer_view.create
//===----------------------------------------------------------------------===//

// Convenience builder taking the element and encoding type codes as integers.
// Each code is materialized as a 32-bit arith.constant (via createOrFold at
// the builder's insertion point) and forwarded to the Value-based overload.
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
                               Value sourceBuffer, Value sourceOffset,
                               Value sourceLength, int32_t elementType,
                               int32_t encodingType, ValueRange shape) {
  // NOTE(review): both constants are created inside the argument list, so the
  // order in which they are emitted is unspecified by C++. Hoisting them into
  // locals would pin an order; confirm no test depends on the current one.
  build(builder, state, sourceBuffer, sourceOffset, sourceLength,
        builder.createOrFold<arith::ConstantIntOp>(state.location, elementType,
                                                   32),
        builder.createOrFold<arith::ConstantIntOp>(state.location, encodingType,
                                                   32),
        shape);
}

// Builds a hal.buffer_view.create from fully materialized operands. Operand
// order is buffer, offset, length, element type, encoding type, then shape.
// The single result is a !hal.buffer_view.
void BufferViewCreateOp::build(OpBuilder &builder, OperationState &state,
                               Value sourceBuffer, Value sourceOffset,
                               Value sourceLength, Value elementType,
                               Value encodingType, ValueRange shape) {
  state.addOperands({sourceBuffer, sourceOffset, sourceLength});
  state.addOperands({elementType, encodingType});
  state.addOperands(shape);
  Type viewType = BufferViewType::get(builder.getContext());
  state.addTypes({viewType});
}

void BufferViewCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The created buffer view prints as %view.
  Value view = getResult();
  setNameFn(view, "view");
}

//===----------------------------------------------------------------------===//
// hal.buffer_view.buffer
//===----------------------------------------------------------------------===//

void BufferViewBufferOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The backing buffer prints as %buffer.
  Value backingBuffer = getResult();
  setNameFn(backingBuffer, "buffer");
}

//===----------------------------------------------------------------------===//
// hal.buffer_view.dim
//===----------------------------------------------------------------------===//

void BufferViewDimOp::inferResultRangesFromOptional(
    ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
  // A dimension is always in [0, MAX_DIM_VALUE], independent of the operand
  // ranges. Index values are modeled as 64 bits here.
  constexpr unsigned kIndexBitWidth = 64;
  APInt lowerBound = APInt::getZero(kIndexBitWidth);
  APInt upperBound(kIndexBitWidth, MAX_DIM_VALUE);
  setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned(
                                  lowerBound, upperBound)));
}

//===----------------------------------------------------------------------===//
// hal.buffer_view.rank
//===----------------------------------------------------------------------===//

void BufferViewRankOp::inferResultRangesFromOptional(
    ArrayRef<IntegerValueRange> argRanges, SetIntLatticeFn setResultRange) {
  // A rank is always in [0, MAX_RANK_VALUE], independent of the operand
  // ranges. Index values are modeled as 64 bits here.
  constexpr unsigned kIndexBitWidth = 64;
  APInt lowerBound = APInt::getZero(kIndexBitWidth);
  APInt upperBound(kIndexBitWidth, MAX_RANK_VALUE);
  setResultRange(getResult(), IntegerValueRange(ConstantIntRanges::fromUnsigned(
                                  lowerBound, upperBound)));
}

//===----------------------------------------------------------------------===//
// hal.channel.create
//===----------------------------------------------------------------------===//

void ChannelCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The new channel prints as %channel.
  Value channel = getResult();
  setNameFn(channel, "channel");
}

//===----------------------------------------------------------------------===//
// hal.channel.split
//===----------------------------------------------------------------------===//

void ChannelSplitOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The split-off channel prints as %channel.
  Value channel = getResult();
  setNameFn(channel, "channel");
}

//===----------------------------------------------------------------------===//
// hal.channel.rank_and_count
//===----------------------------------------------------------------------===//

void ChannelRankAndCountOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Prefix with "ccl_" to distinguish from tensor ranks/counts.
  Value rank = getRank();
  Value count = getCount();
  setNameFn(rank, "ccl_rank");
  setNameFn(count, "ccl_count");
}

//===----------------------------------------------------------------------===//
// hal.command_buffer.create
//===----------------------------------------------------------------------===//

void CommandBufferCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The new command buffer prints as %cmd.
  Value commandBuffer = getResult();
  setNameFn(commandBuffer, "cmd");
}

//===----------------------------------------------------------------------===//
// hal.command_buffer.update_buffer
//===----------------------------------------------------------------------===//

// Operand 1 (the source buffer) is the only subrange operand; it is described
// by the source buffer/size/offset operands plus the update length.
IREE::Util::SubrangeOperand
CommandBufferUpdateBufferOp::getSubrangeOperand(unsigned operandIndex) {
  if (operandIndex != 1) {
    assert(false && "only source is a subrange");
    return {};
  }
  return IREE::Util::SubrangeOperand{getSourceBuffer(), getSourceSize(),
                                     getSourceOffset(), getLength()};
}

// Rewrites the source subrange (operand 1) to |operand|. Only the resource,
// resource size, and offset are replaced; the update length is left as is.
void CommandBufferUpdateBufferOp::setSubrangeOperand(
    unsigned operandIndex, IREE::Util::SubrangeOperand operand) {
  if (operandIndex != 1) {
    assert(false && "only source is a subrange");
    return;
  }
  getSourceBufferMutable().assign(operand.resource);
  getSourceSizeMutable().assign(operand.resourceSize);
  getSourceOffsetMutable().assign(operand.offset);
}

//===----------------------------------------------------------------------===//
// hal.command_buffer.dispatch + .indirect
//===----------------------------------------------------------------------===//

// Builds a direct hal.command_buffer.dispatch.
// |workgroups| must contain exactly three values. The operand segment sizes
// below reserve three single-value segments for them after the entry point.
// Bindings are split into parallel buffer/offset/length operand lists.
void CommandBufferDispatchOp::build(OpBuilder &builder, OperationState &state,
                                    Value commandBuffer, Value executable,
                                    Value entryPoint, ValueRange workgroups,
                                    ValueRange constants,
                                    ArrayRef<BindingValue> bindings,
                                    IREE::HAL::DispatchFlags flags) {
  // Any other count would desynchronize the operand list from the segment
  // sizes attribute and silently produce a malformed op.
  assert(workgroups.size() == 3 &&
         "expected exactly 3 workgroup count values");
  state.addOperands({commandBuffer, executable, entryPoint});
  state.addOperands(workgroups);
  state.addOperands(constants);
  SmallVector<Value> bindingBuffers;
  SmallVector<Value> bindingOffsets;
  SmallVector<Value> bindingLengths;
  for (auto binding : bindings) {
    bindingBuffers.push_back(binding.buffer);
    bindingOffsets.push_back(binding.byteOffset);
    bindingLengths.push_back(binding.byteLength);
  }
  state.addOperands(bindingBuffers);
  state.addOperands(bindingOffsets);
  state.addOperands(bindingLengths);
  state.addAttribute("flags",
                     builder.getAttr<IREE::HAL::DispatchFlagsAttr>(flags));
  state.addAttribute(getOperandSegmentSizeAttr(),
                     builder.getDenseI32ArrayAttr({
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         static_cast<int32_t>(constants.size()),
                         static_cast<int32_t>(bindingBuffers.size()),
                         static_cast<int32_t>(bindingOffsets.size()),
                         static_cast<int32_t>(bindingLengths.size()),
                     }));
}

// Builds an indirect hal.command_buffer.dispatch.indirect.
// The five fixed operands (command buffer, executable, entry point,
// workgroupsBuffer, workgroupsOffset) come first, followed by the constants
// and the parallel binding buffer/offset/length lists.
void CommandBufferDispatchIndirectOp::build(
    OpBuilder &builder, OperationState &state, Value commandBuffer,
    Value executable, Value entryPoint, Value workgroupsBuffer,
    Value workgroupsOffset, ValueRange constants,
    ArrayRef<BindingValue> bindings, IREE::HAL::DispatchFlags flags) {
  state.addOperands({commandBuffer, executable, entryPoint, workgroupsBuffer,
                     workgroupsOffset});
  state.addOperands(constants);

  size_t bindingCount = bindings.size();
  SmallVector<Value> buffers, offsets, lengths;
  buffers.reserve(bindingCount);
  offsets.reserve(bindingCount);
  lengths.reserve(bindingCount);
  for (const auto &binding : bindings) {
    buffers.push_back(binding.buffer);
    offsets.push_back(binding.byteOffset);
    lengths.push_back(binding.byteLength);
  }
  state.addOperands(buffers);
  state.addOperands(offsets);
  state.addOperands(lengths);

  state.addAttribute("flags",
                     builder.getAttr<IREE::HAL::DispatchFlagsAttr>(flags));

  // One single-value segment per fixed operand, then the variadic lists.
  SmallVector<int32_t> segmentSizes(5, 1);
  segmentSizes.push_back(static_cast<int32_t>(constants.size()));
  segmentSizes.push_back(static_cast<int32_t>(buffers.size()));
  segmentSizes.push_back(static_cast<int32_t>(offsets.size()));
  segmentSizes.push_back(static_cast<int32_t>(lengths.size()));
  state.addAttribute(getOperandSegmentSizeAttr(),
                     builder.getDenseI32ArrayAttr(segmentSizes));
}

// Verifies that the parallel binding buffer/offset/length operand lists all
// have the same length.
static LogicalResult verifyDispatchBindings(Operation *op,
                                            ValueRange bindingBuffers,
                                            ValueRange bindingOffsets,
                                            ValueRange bindingLengths) {
  size_t bindingCount = bindingBuffers.size();
  bool consistent = bindingOffsets.size() == bindingCount &&
                    bindingLengths.size() == bindingCount;
  if (consistent)
    return success();
  return op->emitOpError() << "requires that binding fields all have the "
                              "same number of elements";
}

LogicalResult CommandBufferDispatchOp::verify() {
  // The binding buffer/offset/length lists must be the same length.
  return verifyDispatchBindings(getOperation(), getBindingBuffers(),
                                getBindingOffsets(), getBindingLengths());
}

LogicalResult CommandBufferDispatchIndirectOp::verify() {
  // The binding buffer/offset/length lists must be the same length.
  return verifyDispatchBindings(getOperation(), getBindingBuffers(),
                                getBindingOffsets(), getBindingLengths());
}

//===----------------------------------------------------------------------===//
// hal.device.resolve
//===----------------------------------------------------------------------===//

void DeviceResolveOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Name each result by its type: device, allocator, or (integer) queue
  // affinity. Results of any other type are left unnamed.
  for (Value result : getResults()) {
    Type resultType = result.getType();
    StringRef name;
    if (isa<IREE::HAL::DeviceType>(resultType))
      name = "device";
    else if (isa<IREE::HAL::AllocatorType>(resultType))
      name = "allocator";
    else if (isa<IntegerType>(resultType))
      name = "queue_affinity";
    if (!name.empty())
      setNameFn(result, name);
  }
}

//===----------------------------------------------------------------------===//
// hal.device.allocator
//===----------------------------------------------------------------------===//

void DeviceAllocatorOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The device's allocator prints as %allocator.
  Value allocator = getResult();
  setNameFn(allocator, "allocator");
}

//===----------------------------------------------------------------------===//
// hal.device.query
//===----------------------------------------------------------------------===//

LogicalResult DeviceQueryOp::verify() {
  // An optional default value must have the same type as the queried value.
  auto defaultValue = getDefaultValue();
  if (!defaultValue.has_value())
    return success();
  if (defaultValue->getType() == getValue().getType())
    return success();
  return emitOpError() << "type mismatch between result and default value";
}

// Creates a hal.device.query of |category|/|key| on |device| with i1 result
// types and an i1 `0` attribute passed as the default, and returns the
// queried value.
// static
Value DeviceQueryOp::createI1(Location loc, Value device, StringRef category,
                              StringRef key, OpBuilder &builder) {
  IntegerType i1Type = builder.getI1Type();
  StringAttr categoryAttr = builder.getStringAttr(category);
  StringAttr keyAttr = builder.getStringAttr(key);
  auto defaultAttr = builder.getIntegerAttr(i1Type, 0);
  auto queryOp = IREE::HAL::DeviceQueryOp::create(
      builder, loc, i1Type, i1Type, device, categoryAttr, keyAttr, defaultAttr);
  return queryOp.getValue();
}

//===----------------------------------------------------------------------===//
// hal.device.queue.*
//===----------------------------------------------------------------------===//

void DeviceQueueAllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The queue-ordered allocation prints as %transient_buffer.
  Value transientBuffer = getResult();
  setNameFn(transientBuffer, "transient_buffer");
}

// Always reports no size (a null Value), whatever |idx| is.
Value DeviceQueueAllocaOp::getOperandSize(unsigned idx) { return Value(); }

// Reports the value of the ODS `getResultSize()` accessor for any |idx|.
Value DeviceQueueAllocaOp::getResultSize(unsigned idx) {
  Value allocationSize = getResultSize();
  return allocationSize;
}

// Verifies that a device queue operation does not wait on and signal the same
// fence. Using the same fence is allowed only when it is absent (a null Value)
// or produced by util.null.
static LogicalResult verifyDeviceQueueFences(Operation *queueOp,
                                             Value waitFence,
                                             Value signalFence) {
  if (waitFence != signalFence)
    return success();
  // The shared fence may be a block argument, which has no defining op.
  // getDefiningOp<NullOp>() handles that case. The previous
  // isa<NullOp>(getDefiningOp()) asserted on a null pointer instead.
  if (!waitFence || waitFence.getDefiningOp<IREE::Util::NullOp>())
    return success();
  return queueOp->emitOpError() << "device queue operations cannot wait and "
                                   "signal on the same fence.";
}

LogicalResult DeviceQueueAllocaOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

LogicalResult DeviceQueueDeallocaOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

LogicalResult DeviceQueueFillOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

LogicalResult DeviceQueueUpdateOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

LogicalResult DeviceQueueCopyOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

LogicalResult DeviceQueueReadOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

LogicalResult DeviceQueueWriteOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

LogicalResult DeviceQueueBarrierOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

LogicalResult DeviceQueueExecuteOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

// Builds a hal.device.queue.execute.indirect. The fixed operands (device,
// queue affinity, wait/signal fences, command buffer) come first, followed by
// the bindings split into parallel buffer/offset/length operand lists.
void DeviceQueueExecuteIndirectOp::build(OpBuilder &builder,
                                         OperationState &state, Value device,
                                         Value queueAffinity, Value waitFence,
                                         Value signalFence, Value commandBuffer,
                                         ArrayRef<BindingValue> bindings,
                                         IREE::HAL::ExecuteFlagBitfield flags) {
  state.addOperands(
      {device, queueAffinity, waitFence, signalFence, commandBuffer});

  size_t bindingCount = bindings.size();
  SmallVector<Value> buffers, offsets, lengths;
  buffers.reserve(bindingCount);
  offsets.reserve(bindingCount);
  lengths.reserve(bindingCount);
  for (const auto &binding : bindings) {
    buffers.push_back(binding.buffer);
    offsets.push_back(binding.byteOffset);
    lengths.push_back(binding.byteLength);
  }
  state.addOperands(buffers);
  state.addOperands(offsets);
  state.addOperands(lengths);

  auto flagsAttr = builder.getAttr<IREE::HAL::ExecuteFlagBitfieldAttr>(flags);
  state.addAttribute("flags", flagsAttr);
}

LogicalResult DeviceQueueExecuteIndirectOp::verify() {
  // Waiting on and signaling the same (non-null) fence is rejected.
  Value waitFence = getWaitFence();
  Value signalFence = getSignalFence();
  return verifyDeviceQueueFences(getOperation(), waitFence, signalFence);
}

//===----------------------------------------------------------------------===//
// hal.devices.*
//===----------------------------------------------------------------------===//

void DevicesCountOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // The device count prints as %device_count.
  Value deviceCount = getResult();
  setNameFn(deviceCount, "device_count");
}

void DevicesGetOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // With a constant index the result is named device_<index>; otherwise it
  // falls back to device_n.
  APInt constantIndex;
  if (!matchPattern(getIndex(), m_ConstantInt(&constantIndex))) {
    setNameFn(getResult(), "device_n");
    return;
  }
  llvm::SmallString<16> resultName("device_");
  constantIndex.toStringUnsigned(resultName);
  setNameFn(getResult(), resultName);
}

//===----------------------------------------------------------------------===//
// hal.executable.source
//===----------------------------------------------------------------------===//

LogicalResult ExecutableSourceOp::verify() {
  // At most one hal.executable.condition op may appear in the body.
  auto conditionOps = getOps<IREE::HAL::ExecutableConditionOp>();
  if (llvm::range_size(conditionOps) <= 1)
    return success();
  return emitOpError() << "only one condition op is allowed in an executable";
}

//===----------------------------------------------------------------------===//
// hal.executable
//===----------------------------------------------------------------------===//

// Builds a hal.executable named |name| with a single body region that has its
// implicit terminator already in place.
void ExecutableOp::build(OpBuilder &builder, OperationState &state,
                         StringRef name) {
  Region *bodyRegion = state.addRegion();
  ensureTerminator(*bodyRegion, builder, state.location);
  StringAttr symbolName = builder.getStringAttr(name);
  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(), symbolName);
}

// No executable-level invariants are checked yet; every executable verifies.
LogicalResult ExecutableOp::verify() {
  // TODO(benvanik): check export name conflicts.
  LogicalResult result = success();
  return result;
}

//===----------------------------------------------------------------------===//
// hal.executable.export
//===----------------------------------------------------------------------===//

// Verifies that the export condition region matches the expected signature:
// a leading !hal.device argument followed only by index arguments, with every
// hal.return in the region yielding a single i1.
// Returns success if the region is empty.
static LogicalResult verifyExportConditionRegion(Operation *op,
                                                 Region &region) {
  if (region.empty())
    return success();

  // Supported signature: (!hal.device, index, index, ...).
  auto arguments = region.getArguments();
  bool validArguments =
      !arguments.empty() &&
      llvm::isa<IREE::HAL::DeviceType>(arguments.front().getType()) &&
      llvm::all_of(arguments.drop_front(1), [](BlockArgument blockArg) {
        // All arguments after the device need to be of type index (today).
        return llvm::isa<IndexType>(blockArg.getType());
      });
  if (!validArguments) {
    return op->emitOpError(
        "expected condition region to take (%device: !hal.device, "
        "%workload_0: index, %workload_1: index, ...)");
  }

  // Verify the return type is i1. emitOpError is used here as well so every
  // diagnostic from this verifier carries the op-name prefix.
  for (auto returnOp : region.getOps<IREE::HAL::ReturnOp>()) {
    auto returnTypes = returnOp.getOperandTypes();
    if (returnTypes.size() != 1 || !llvm::all_of(returnTypes, [](Type type) {
          return type.isInteger(1);
        })) {
      return op->emitOpError("condition region must return a boolean value");
    }
  }

  return success();
}

// Verifies the optional condition and workgroup count regions. A condition
// region and a condition fallback must be present together or not at all.
LogicalResult ExecutableExportOp::verify() {
  const bool hasFallback = static_cast<bool>(getConditionFallbackAttr());

  if (!getConditionBody()) {
    if (hasFallback) {
      return emitOpError() << "fallback must only be present if a condition "
                              "region is defined";
    }
  } else {
    if (!llvm::hasSingleElement(getCondition())) {
      return emitOpError()
             << "expected a single region block for the condition";
    }
    if (failed(verifyExportConditionRegion(*this, getCondition())))
      return failure();
    if (!hasFallback) {
      return emitOpError()
             << "must have a fallback if a condition region is defined";
    }
  }

  // The workgroup count region is optional but must be a single block.
  if (!getWorkgroupCountBody())
    return success();
  if (!llvm::hasSingleElement(getWorkgroupCount())) {
    return emitOpError()
           << "expected a single region block for the workgroup count";
  }
  if (failed(verifyWorkgroupCountRegion(*this, getWorkgroupCount())))
    return failure();
  return success();
}

// Returns true if |lhs| and |rhs| declare the same number of block arguments
// with pairwise-identical types.
static bool compareArgumentTypes(Block *lhs, Block *rhs) {
  auto lhsTypes = lhs->getArgumentTypes();
  auto rhsTypes = rhs->getArgumentTypes();
  return lhsTypes.size() == rhsTypes.size() && llvm::equal(lhsTypes, rhsTypes);
}

// Verifies that a referenced condition fallback is an export with the same
// layout, and that any condition/workgroup count regions present on both
// exports take identical argument types.
LogicalResult
ExecutableExportOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  auto fallbackAttr = getConditionFallbackAttr();
  if (!fallbackAttr)
    return success();

  // Ensure the fallback is defined.
  auto fallbackOp =
      symbolTable.lookupNearestSymbolFrom<IREE::HAL::ExecutableExportOp>(
          *this, fallbackAttr);
  if (!fallbackOp)
    return emitOpError() << "undefined fallback entry point: " << fallbackAttr;

  // Layouts must match exactly.
  auto baseLayout = getLayout();
  auto fallbackLayout = fallbackOp.getLayout();
  if (baseLayout != fallbackLayout) {
    return emitOpError() << "fallback layout does not match (base has "
                         << baseLayout << ", fallback has " << fallbackLayout
                         << ")";
  }

  // Workgroup count signature and condition signatures must match to allow
  // us to chain them during materialization. Only checked when both sides
  // define the region.
  auto signaturesDiffer = [](Block *base, Block *fallback) {
    return base && fallback && !compareArgumentTypes(base, fallback);
  };
  if (signaturesDiffer(getConditionBody(), fallbackOp.getConditionBody())) {
    return emitOpError() << "fallback condition argument mismatch; "
                            "fallback args must match exactly";
  }
  if (signaturesDiffer(getWorkgroupCountBody(),
                       fallbackOp.getWorkgroupCountBody())) {
    return emitOpError() << "fallback workgroup count argument mismatch; "
                            "fallback args must match exactly";
  }
  return success();
}

// Materializes the export condition at the insertion point of |builder| and
// returns the resulting i1. Without a condition region this is constant true.
Value ExecutableExportOp::calculateCondition(Location loc, Value device,
                                             ValueRange workload,
                                             OpBuilder &builder) {
  Block *conditionBlock = getConditionBody();
  if (conditionBlock == nullptr)
    return arith::ConstantIntOp::create(builder, loc, 1, 1);

  // TODO(benvanik): replace with region inlining util.
  IRMapping mapping;
  auto blockArgs = conditionBlock->getArguments();
  mapping.map(blockArgs.front(), device);
  // Bind only as many workload values as both the region and the caller
  // provide. Once the number of workload entries is explicitly propagated to
  // the `hal.executable.variant` these will match by construction.
  auto workloadArgs = blockArgs.drop_front(1);
  size_t numBound = std::min<size_t>(workloadArgs.size(), workload.size());
  for (size_t i = 0; i < numBound; ++i)
    mapping.map(workloadArgs[i], workload[i]);

  for (Operation &nestedOp : conditionBlock->without_terminator())
    builder.clone(nestedOp, mapping);

  auto returnOp = cast<IREE::HAL::ReturnOp>(conditionBlock->getTerminator());
  assert(returnOp.getNumOperands() == 1 && "must return bool");
  Value conditionValue = returnOp.getOperands().front();
  return mapping.lookup(conditionValue);
}

// Calculates the workgroup count (x, y, z) given the total N-dimensional
// |workload| and specific |workgroupSize|, emitting index arithmetic at the
// insertion point of |builder|.
//
// - 1-D to 3-D workloads: each dimension is ceil-divided by the matching
//   workgroup size; missing trailing dimensions are padded with 1.
// - More than 3 dimensions: the workload is flattened to a single product and
//   then split across x/y/z (see TODO(#4140) below; known to be incorrect).
static std::array<Value, 3>
calculateWorkloadWorkgroupCount(Location loc, ValueRange workload,
                                const std::array<Value, 3> &workgroupSize,
                                OpBuilder &builder) {
  std::array<Value, 3> result;

  auto constantOne = builder.createOrFold<arith::ConstantIndexOp>(loc, 1);
  if (workload.size() <= 3) {
    // 1-D to 3-D are easy (missing dimensions are padded with 1) and divide by
    // workgroup size.
    for (int i = 0; i < 3; ++i) {
      // Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
      Value workloadI = i < workload.size() ? workload[i] : constantOne;
      workloadI = builder.createOrFold<arith::SubIOp>(
          loc,
          builder.createOrFold<arith::AddIOp>(loc, workloadI, workgroupSize[i]),
          constantOne);
      result[i] = builder.createOrFold<arith::DivUIOp>(loc, workloadI,
                                                       workgroupSize[i]);
    }
  } else {
    // TODO(#4140): remapping of N-D to 3-D: this is not how you do this!
    // Flatten the workload into the product of all of its dimensions.
    Value flatWorkload = constantOne;
    for (auto workloadI : workload) {
      flatWorkload =
          builder.createOrFold<arith::MulIOp>(loc, flatWorkload, workloadI);
    }
    for (int i = 0; i < 3; ++i) {
      // Round up: (workload[i] + workgroup_size - 1) / workgroup_size;
      // Note that |rounded| holds the numerator (flat + size - 1), not the
      // divided result.
      auto rounded = builder.createOrFold<arith::SubIOp>(
          loc,
          builder.createOrFold<arith::AddIOp>(loc, flatWorkload,
                                              workgroupSize[i]),
          constantOne);
      auto workgroupCountI =
          builder.createOrFold<arith::DivUIOp>(loc, rounded, workgroupSize[i]);
      result[i] = workgroupCountI;

      // Multiply back out and subtract from invocations.
      // NOTE(review): this multiplies the count by |rounded| rather than by
      // workgroupSize[i], so the remaining invocation count does not
      // correspond to "invocations not yet covered" — confirm intent before
      // relying on this branch.
      flatWorkload = builder.createOrFold<arith::SubIOp>(
          loc, flatWorkload,
          builder.createOrFold<arith::MulIOp>(loc, workgroupCountI, rounded));
    }
  }

  return result;
}

// Inlines the workgroup count region |body| at the insertion point of |builder|
// and returns the mapped (x, y, z) values from its hal.return. Argument 0 of
// |body| is bound to |device| and the following arguments to the leading
// entries of |workload|.
static std::array<Value, 3>
calculateWorkgroupCountFromRegion(Location loc, Block *body, Value device,
                                  ValueRange workload, OpBuilder &builder) {
  // TODO(benvanik): replace with region inlining util.
  IRMapping mapping;
  auto bodyArgs = body->getArguments();
  mapping.map(bodyArgs.front(), device);
  // Bind only as many workload values as both the region and the caller
  // provide. Once the number of workload entries is explicitly propagated to
  // the `hal.executable.variant` these will match by construction.
  auto workloadArgs = bodyArgs.drop_front(1);
  size_t numBound = std::min<size_t>(workloadArgs.size(), workload.size());
  for (size_t i = 0; i < numBound; ++i)
    mapping.map(workloadArgs[i], workload[i]);

  for (Operation &nestedOp : body->without_terminator())
    builder.clone(nestedOp, mapping);

  auto returnOp = cast<IREE::HAL::ReturnOp>(body->getTerminator());
  assert(returnOp.getNumOperands() == 3 && "must return xyz");
  std::array<Value, 3> workgroupCount;
  for (unsigned dim = 0; dim < 3; ++dim)
    workgroupCount[dim] = mapping.lookup(returnOp->getOperand(dim));
  return workgroupCount;
}

// Calculates the workgroup count (x, y, z) for dispatching to the entry point.
// The provided N-dimensional |workload| is the total number of invocations
// required as calculated by the generic workload logic (basically, number of
// output elements in tensors). An explicit workgroup count region takes
// precedence; otherwise the count is derived from the workgroup size.
std::array<Value, 3> ExecutableExportOp::calculateWorkgroupCount(
    Location loc, Value device, ValueRange workload, OpBuilder &builder) {
  if (Block *countBody = getWorkgroupCountBody()) {
    return calculateWorkgroupCountFromRegion(loc, countBody, device, workload,
                                             builder);
  }
  std::array<Value, 3> workgroupSize =
      calculateWorkgroupSize(loc, device, workload, builder);
  return calculateWorkloadWorkgroupCount(loc, workload, workgroupSize, builder);
}

// Calculates the workgroup size (x, y, z): the extent of a single workgroup in
// each dimension. No workgroup size is specified here, so this is always
// [1, 1, 1], which makes the workgroup count model the workload extents.
std::array<Value, 3> ExecutableExportOp::calculateWorkgroupSize(
    Location loc, Value device, ValueRange workload, OpBuilder &builder) {
  std::array<Value, 3> sizes;
  for (Value &size : sizes)
    size = builder.createOrFold<arith::ConstantIndexOp>(loc, 1);
  return sizes;
}

//===----------------------------------------------------------------------===//
// hal.executable.variant
//===----------------------------------------------------------------------===//

// Builds a variant with an empty (terminated) body, the given symbol name and
// the executable target attribute.
void ExecutableVariantOp::build(OpBuilder &builder, OperationState &state,
                                StringRef symName,
                                IREE::HAL::ExecutableTargetAttr target) {
  Region *bodyRegion = state.addRegion();
  ensureTerminator(*bodyRegion, builder, state.location);
  std::pair<StringRef, Attribute> namedAttrs[] = {
      {mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(symName)},
      {"target", target},
  };
  for (auto &[attrName, attrValue] : namedAttrs)
    state.addAttribute(attrName, attrValue);
}

// Verifies that at most one hal.executable.condition is nested directly in the
// variant.
LogicalResult ExecutableVariantOp::verify() {
  auto numConditionOps =
      llvm::range_size(getOps<IREE::HAL::ExecutableConditionOp>());
  if (numConditionOps <= 1)
    return success();
  return emitOpError() << "only one condition op is allowed in a variant";
}

// Maps each constant key to an ordinal equal to its position when the keys of
// all hal.executable.constant.block ops are concatenated in order. If a key
// appears more than once, the first occurrence wins.
// NOTE(review): presumably this mirrors how the blocks' results are laid out
// when they are evaluated and concatenated; confirm against the pass that
// inlines the constant blocks.
DenseMap<Attribute, int> ExecutableVariantOp::gatherConstantOrdinals() {
  DenseMap<Attribute, int> map;
  // Ordinal of the first result of the current block. This is tracked
  // independently of map.size(): using the map size undercounts when a key is
  // repeated, and that can make two distinct keys share an ordinal.
  int baseOrdinal = 0;
  for (auto blockOp : getConstantBlockOps()) {
    auto keys = blockOp.getKeys();
    for (auto [i, keyAttr] : llvm::enumerate(keys)) {
      map.try_emplace(keyAttr, baseOrdinal + static_cast<int>(i));
    }
    baseOrdinal += static_cast<int>(keys.size());
  }
  return map;
}

// Creates a hal.executable.condition at the start of the variant body and
// leaves |builder| positioned at the start of its new entry block. Returns the
// entry block's first argument (the device).
Value ExecutableVariantOp::createConditionOp(OpBuilder &builder) {
  assert(!getConditionOp() && "condition op already exists");

  Block &variantBody = getRegion().front();
  builder.setInsertionPointToStart(&variantBody);
  auto conditionOp =
      IREE::HAL::ExecutableConditionOp::create(builder, getLoc());

  Block *entryBlock = conditionOp.addEntryBlock();
  builder.setInsertionPointToStart(entryBlock);
  return entryBlock->getArgument(0);
}

// Builds an i1 that is true when |device| supports this variant. The result
// combines a device query of the "hal.executable.format" key for the target
// format with the variant's optional condition region. The region, if any, is
// cloned into an scf.execute_region at the insertion point of |builder|.
Value ExecutableVariantOp::buildCondition(Value device, OpBuilder &builder) {
  // Base case dependent on target information.
  Value selected = IREE::HAL::DeviceQueryOp::createI1(
      getLoc(), device, "hal.executable.format",
      getTarget().getFormat().getValue(), builder);

  auto conditionOp = getConditionOp();
  if (!conditionOp)
    return selected;

  // Factor in the variant condition region. The device argument is mapped
  // directly, so the cloned entry block has no arguments.
  auto regionOp = scf::ExecuteRegionOp::create(builder, conditionOp.getLoc(),
                                               builder.getI1Type());
  IRMapping mapper;
  mapper.map(conditionOp.getRegion().getArgument(0), device);
  conditionOp.getRegion().cloneInto(&regionOp.getRegion(), mapper);

  // Rewrite hal.return terminators into scf.yield. A dedicated builder is used
  // (rather than shadowing |builder|) so the caller's insertion point is
  // untouched and the two builders cannot be confused.
  for (auto returnOp :
       llvm::make_early_inc_range(regionOp.getOps<IREE::HAL::ReturnOp>())) {
    OpBuilder yieldBuilder(returnOp);
    scf::YieldOp::create(yieldBuilder, returnOp.getLoc(),
                         returnOp.getOperands());
    returnOp.erase();
  }

  return arith::AndIOp::create(builder, getLoc(), selected,
                               regionOp.getResult(0));
}

//===----------------------------------------------------------------------===//
// hal.executable.condition
//===----------------------------------------------------------------------===//

// Delegates to the shared target condition region verifier.
LogicalResult ExecutableConditionOp::verify() {
  return verifyTargetConditionRegion(*this, getBody());
}

// Builds a condition op with the fixed target-condition function type, an
// empty body region and any extra |attrs|.
void ExecutableConditionOp::build(OpBuilder &builder, OperationState &result,
                                  ArrayRef<NamedAttribute> attrs) {
  auto conditionType = getTargetConditionRegionType(builder.getContext());
  result.addAttribute("function_type", TypeAttr::get(conditionType));
  result.addRegion();
  result.attributes.append(attrs.begin(), attrs.end());
}

// Parses the condition region followed by an optional `attributes {...}`
// dictionary. function_type is not spelled in the syntax and is reconstructed
// here.
ParseResult ExecutableConditionOp::parse(OpAsmParser &parser,
                                         OperationState &result) {
  Region *bodyRegion = result.addRegion();
  if (parseTargetConditionRegion(parser, *bodyRegion))
    return failure();
  auto functionType = getTargetConditionRegionType(parser.getContext());
  result.addAttribute("function_type", TypeAttr::get(functionType));
  return parser.parseOptionalAttrDictWithKeyword(result.attributes);
}

// Prints the condition region and any attributes other than function_type,
// which the parser reconstructs.
void ExecutableConditionOp::print(OpAsmPrinter &p) {
  printTargetConditionRegion(p, getOperation(), getBody());
  SmallVector<StringRef, 1> elidedAttrs = {"function_type"};
  p.printOptionalAttrDictWithKeyword((*this)->getAttrs(), elidedAttrs);
}

// Appends the entry block to an empty body, with one argument per function
// argument type (all located at the op's location). The region takes
// ownership of the block.
Block *ExecutableConditionOp::addEntryBlock() {
  assert(empty() && "function already has an entry block");
  auto argTypes = getArgumentTypes();
  SmallVector<Location> argLocs(argTypes.size(), getLoc());
  Block *entryBlock = new Block();
  entryBlock->addArguments(argTypes, argLocs);
  push_back(entryBlock);
  return entryBlock;
}

// Appends a new argument-less block after the existing entry block and returns
// it. The region takes ownership of the block.
Block *ExecutableConditionOp::addBlock() {
  assert(!empty() && "function should at least have an entry block");
  Block *newBlock = new Block();
  push_back(newBlock);
  return newBlock;
}

//===----------------------------------------------------------------------===//
// hal.executable.constant.block
//===----------------------------------------------------------------------===//

// Parses:
//   (<args>) -> <result types> as "key" | ("key0", "key1", ...)
//       [attributes {...}] { <body> }
// A single result uses a bare key string; any other result count uses the
// (optionally parenthesized) comma-separated form. The body region is required
// and must not be empty.
ParseResult ExecutableConstantBlockOp::parse(OpAsmParser &parser,
                                             OperationState &result) {
  auto &builder = parser.getBuilder();

  // Parse the function signature.
  SmallVector<OpAsmParser::Argument> entryArgs;
  bool isVariadic = false;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Type> resultTypes;
  if (mlir::function_interface_impl::parseFunctionSignatureWithArguments(
          parser, /*allowVariadic=*/false, entryArgs, isVariadic, resultTypes,
          resultAttrs)) {
    return failure();
  }
  SmallVector<Type> argTypes;
  for (auto &arg : entryArgs)
    argTypes.push_back(arg.type);
  auto fnType = builder.getFunctionType(argTypes, resultTypes);
  result.addAttribute(getFunctionTypeAttrName(result.name),
                      TypeAttr::get(fnType));

  // Parse the keys used for each yielded constant value.
  // There must be one key per result (checked by the verifier, not here). Note
  // that parens are omitted when exactly one result is present; in that case a
  // parenthesized key is not accepted.
  SmallVector<Attribute> keyAttrs;
  if (failed(parser.parseKeyword("as")))
    return failure();
  if (resultTypes.size() == 1) {
    std::string key;
    if (failed(parser.parseString(&key)))
      return failure();
    keyAttrs.push_back(builder.getStringAttr(key));
  } else {
    if (failed(parser.parseCommaSeparatedList(
            AsmParser::Delimiter::OptionalParen,
            [&]() {
              std::string key;
              if (failed(parser.parseString(&key)))
                return failure();
              keyAttrs.push_back(builder.getStringAttr(key));
              return success();
            },
            "containing a 1:1 list of keys per yielded value"))) {
      return failure();
    }
  }
  result.addAttribute("keys", builder.getArrayAttr(keyAttrs));

  // If function attributes are present, parse them.
  NamedAttrList parsedAttributes;
  if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) {
    return failure();
  }
  result.attributes.append(parsedAttributes);

  // Attach the parsed per-argument and per-result attribute dictionaries.
  assert(resultAttrs.size() == resultTypes.size());
  mlir::call_interface_impl::addArgAndResultAttrs(
      builder, result, entryArgs, resultAttrs, getArgAttrsAttrName(result.name),
      getResAttrsAttrName(result.name));

  // Parse the (required) function body. parseRegion accepts an empty `{}`
  // region, so emptiness is rejected explicitly below.
  auto *body = result.addRegion();
  SMLoc loc = parser.getCurrentLocation();
  if (failed(parser.parseRegion(*body, entryArgs,
                                /*enableNameShadowing=*/false))) {
    return failure();
  }
  // Function body was parsed; make sure it's not empty.
  if (body->empty()) {
    return parser.emitError(loc, "expected non-empty function body");
  }

  return success();
}

// Prints the signature, the `as` key list and the body. The key list format
// mirrors the parser: keys are wrapped in parens unless there is exactly one
// result.
void ExecutableConstantBlockOp::print(OpAsmPrinter &p) {
  Operation *op = getOperation();
  ArrayRef<Type> resultTypes = getResultTypes();
  mlir::function_interface_impl::printFunctionSignature(
      p, cast<mlir::FunctionOpInterface>(op), getArgumentTypes(),
      /*isVariadic=*/false, resultTypes);

  const bool wrapKeys = resultTypes.size() != 1;
  p << " as ";
  if (wrapKeys)
    p << '(';
  llvm::interleaveComma(getKeys().getValue(), p,
                        [&](Attribute keyAttr) { p << keyAttr; });
  if (wrapKeys)
    p << ')';

  // function_type and keys are already conveyed by the syntax above.
  mlir::function_interface_impl::printFunctionAttributes(
      p, op, {getFunctionTypeAttrName(), getKeysAttrName()});
  p << " ";
  p.printRegion(getBody(), /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/true);
}

// Verifies that the initializer takes no arguments or a single !hal.device,
// returns only i32 values, and has exactly one key per result.
LogicalResult ExecutableConstantBlockOp::verify() {
  auto argTypes = getArgumentTypes();
  const bool validArgs =
      argTypes.empty() || (argTypes.size() == 1 &&
                           llvm::isa<IREE::HAL::DeviceType>(argTypes[0]));
  if (!validArgs) {
    return emitOpError() << "initializer must take a !hal.device or nothing";
  }

  // Verify the return types are all i32 (today).
  auto resultTypes = getResultTypes();
  for (size_t index = 0, count = resultTypes.size(); index < count; ++index) {
    Type resultType = resultTypes[index];
    if (resultType.isInteger(32))
      continue;
    return emitOpError() << "initializer must return only i32 values (result "
                         << index << " is " << resultType << ")";
  }

  if (getNumResults() != getKeys().size())
    return emitOpError() << "must have one key for every result";
  return success();
}

//===----------------------------------------------------------------------===//
// hal.executable.binary
//===----------------------------------------------------------------------===//

// Builds a binary from raw bytes: |data| is stored as a vector<Nxi8> dense
// attribute alongside the symbol name and format string.
void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state,
                               StringRef symName, StringRef format,
                               std::vector<uint8_t> data) {
  Type byteType = builder.getIntegerType(8);
  auto dataType =
      VectorType::get({static_cast<int64_t>(data.size())}, byteType);
  auto dataAttr = DenseIntElementsAttr::get(dataType, data);

  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
                     builder.getStringAttr(symName));
  state.addAttribute("format", builder.getStringAttr(format));
  state.addAttribute("data", dataAttr);
}

// Builds a binary from an already-materialized format and data attribute.
void ExecutableBinaryOp::build(OpBuilder &builder, OperationState &state,
                               StringRef symName, StringAttr format,
                               DenseIntElementsAttr data) {
  std::pair<StringRef, Attribute> namedAttrs[] = {
      {mlir::SymbolTable::getSymbolAttrName(), builder.getStringAttr(symName)},
      {"format", format},
      {"data", data},
  };
  for (auto &[attrName, attrValue] : namedAttrs)
    state.addAttribute(attrName, attrValue);
}

//===----------------------------------------------------------------------===//
// hal.executable.create
//===----------------------------------------------------------------------===//

// Names the result `%executable` in printed IR.
void ExecutableCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // TODO(benvanik): name after sanitized symbol.
  Value executable = getResult();
  setNameFn(executable, "executable");
}

//===----------------------------------------------------------------------===//
// hal.executable.lookup
//===----------------------------------------------------------------------===//

// Names the result `%exe` in printed IR.
void ExecutableLookupOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // TODO(benvanik): name after sanitized symbol.
  Value executable = getResult();
  setNameFn(executable, StringRef("exe"));
}

//===----------------------------------------------------------------------===//
// hal.executable.export.ordinal
//===----------------------------------------------------------------------===//

// Names the result `%ordinal` in printed IR.
void ExecutableExportOrdinalOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // TODO(benvanik): name after sanitized symbol.
  Value ordinal = getResult();
  setNameFn(ordinal, StringRef("ordinal"));
}

//===----------------------------------------------------------------------===//
// hal.interface.constant.load
//===----------------------------------------------------------------------===//

// Verifies that the ordinal indexes one of the constants declared by the
// pipeline layout.
LogicalResult InterfaceConstantLoadOp::verify() {
  uint64_t ordinal = getOrdinal().getZExtValue();
  auto numConstants = getLayout().getConstants();
  if (ordinal < numConstants)
    return success();
  return emitOpError("push constant ordinal out of bounds");
}

//===----------------------------------------------------------------------===//
// hal.interface.binding.subspan
//===----------------------------------------------------------------------===//

// Convenience builder that accepts the descriptor flags as an optional enum;
// absent flags map to a null attribute.
void InterfaceBindingSubspanOp::build(OpBuilder &builder,
                                      OperationState &result, Type resultType,
                                      IREE::HAL::PipelineLayoutAttr layout,
                                      APInt binding, Value byte_offset,
                                      ValueRange dynamic_dims,
                                      IntegerAttr alignment,
                                      std::optional<DescriptorFlags> flags) {
  auto descriptorAttr =
      flags ? IREE::HAL::DescriptorFlagsAttr::get(builder.getContext(), *flags)
            : IREE::HAL::DescriptorFlagsAttr{};
  build(builder, result, resultType, layout, binding, byte_offset, dynamic_dims,
        alignment, descriptorAttr);
}

// Verifies that a shaped result has one dynamic dimension SSA value per
// dynamic dimension, and that the binding ordinal is within the layout.
LogicalResult InterfaceBindingSubspanOp::verify() {
  auto dynamicDims = getDynamicDims();
  if (auto shapedType = llvm::dyn_cast<ShapedType>(getType())) {
    auto numDynamicDims = shapedType.getNumDynamicDims();
    if (numDynamicDims != dynamicDims.size()) {
      return emitOpError("result type ")
             << getType() << " has " << numDynamicDims
             << " dynamic dimensions but " << dynamicDims.size()
             << " associated dimension SSA values";
    }
  }

  auto layout = getLayout();
  uint64_t bindingOrdinal = getBinding().getZExtValue();
  if (bindingOrdinal < layout.getBindings().size())
    return success();
  return emitOpError("binding ordinal ")
         << bindingOrdinal << " out of bounds in layout " << layout;
}

// Returns the pipeline layout entry addressed by this op's binding ordinal.
IREE::HAL::PipelineBindingAttr
InterfaceBindingSubspanOp::getPipelineBindingAttr() {
  auto layoutAttr = getLayout();
  auto bindingOrdinal = getBinding();
  return layoutAttr.getBinding(bindingOrdinal);
}

// Returns the descriptor type declared for this op's binding in the layout.
IREE::HAL::DescriptorType InterfaceBindingSubspanOp::getDescriptorType() {
  return getPipelineBindingAttr().getType();
}

// Returns the alignment attribute as a MaybeAlign, or none when the op has no
// alignment attribute.
llvm::MaybeAlign InterfaceBindingSubspanOp::getBaseAlignment() {
  auto alignmentAttrValue = getAlignment();
  if (!alignmentAttrValue)
    return std::nullopt;
  return llvm::MaybeAlign(alignmentAttrValue->getZExtValue());
}

// Computes the best-known alignment of the subspan's base pointer:
// - no binding alignment: the element type's natural alignment (1 for
//   non-shaped results);
// - binding alignment and no byte offset: the binding alignment;
// - binding alignment and an offset whose value/alignment can be inferred:
//   the common alignment of the two;
// - otherwise: the natural alignment.
llvm::Align InterfaceBindingSubspanOp::calculateAlignment() {
  // For example, a memref<?xi32> is known to be at least 4-byte aligned.
  llvm::Align naturalAlignment(1);
  if (auto shapedType = llvm::dyn_cast<ShapedType>(getType())) {
    naturalAlignment = llvm::Align(
        IREE::Util::getRoundedElementByteWidth(shapedType.getElementType()));
  }

  llvm::MaybeAlign baseAlignment = getBaseAlignment();
  if (!baseAlignment)
    return naturalAlignment;

  Value byteOffset = getByteOffset();
  if (!byteOffset)
    return *baseAlignment;

  auto offsetOrAlignment = lookupOffsetOrAlignment(byteOffset);
  return offsetOrAlignment.has_value()
             ? llvm::commonAlignment(*baseAlignment, *offsetOrAlignment)
             : naturalAlignment;
}

// Reifies the result shape from the static shape plus the dynamic dimension
// operands. Fails for non-shaped results.
LogicalResult InterfaceBindingSubspanOp::reifyResultShapes(
    OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  auto shapedType = dyn_cast<ShapedType>(getResult().getType());
  if (!shapedType)
    return failure();
  reifiedReturnShapes.emplace_back(mlir::getMixedValues(
      shapedType.getShape(), getDynamicDims(), builder));
  return success();
}

//===----------------------------------------------------------------------===//
// hal.interface.workgroup.*
//===----------------------------------------------------------------------===//

// Names |result| `<prefix>x`, `<prefix>y` or `<prefix>z` for dimensions 0, 1
// and 2. Any other dimension leaves the result unnamed.
static void getAsmResultNamesForInterfaceWorkgroupOp(
    StringRef prefix, const APInt &dimension, Value result,
    function_ref<void(Value, StringRef)> setNameFn) {
  static const char *const kAxisSuffixes[] = {"x", "y", "z"};
  uint64_t dim = dimension.getZExtValue();
  if (dim > 2)
    return;
  setNameFn(result, (prefix + kAxisSuffixes[dim]).str());
}

// Sets the integer range of |result|. |minimum| is the smallest possible value
// (0 for ID-like ops, 1 for count-like ones).
// - With an |upperBound| the range is [minimum, upperBound + minimum - 1]
//   (unsigned): [0, upperBound - 1] for IDs and [1, upperBound] for counts.
// - Without one the range is [minimum, signed max].
static void setResultRangesForInterfaceWorkgroupOp(
    Value result, const std::optional<APInt> &upperBound,
    SetIntRangeFn setResultRanges, int64_t minimum) {
  unsigned width = ConstantIntRanges::getStorageBitwidth(result.getType());
  APInt lower(width, minimum);
  if (upperBound) {
    APInt upper = *upperBound + minimum - 1;
    setResultRanges(result, ConstantIntRanges::fromUnsigned(lower, upper));
  } else {
    setResultRanges(result, ConstantIntRanges::fromSigned(
                                lower, APInt::getSignedMaxValue(width)));
  }
}

// Names the result %workgroup_id_{x,y,z} based on the queried dimension.
void InterfaceWorkgroupIDOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef prefix = "workgroup_id_";
  getAsmResultNamesForInterfaceWorkgroupOp(prefix, getDimension(), getResult(),
                                           setNameFn);
}

// Workgroup IDs start at zero and are below the upper bound, if one is known.
void InterfaceWorkgroupIDOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
  constexpr int64_t kMinWorkgroupId = 0;
  setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(),
                                         setResultRanges, kMinWorkgroupId);
}

// Names the result %workgroup_count_{x,y,z} based on the queried dimension.
void InterfaceWorkgroupCountOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef prefix = "workgroup_count_";
  getAsmResultNamesForInterfaceWorkgroupOp(prefix, getDimension(), getResult(),
                                           setNameFn);
}

// Workgroup counts are at least one and at most the upper bound, if known.
void InterfaceWorkgroupCountOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
  constexpr int64_t kMinWorkgroupCount = 1;
  setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(),
                                         setResultRanges, kMinWorkgroupCount);
}

// Names the result %workgroup_size_{x,y,z} based on the queried dimension.
void InterfaceWorkgroupSizeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  StringRef prefix = "workgroup_size_";
  getAsmResultNamesForInterfaceWorkgroupOp(prefix, getDimension(), getResult(),
                                           setNameFn);
}

// Workgroup sizes are at least one and at most the upper bound, if known.
void InterfaceWorkgroupSizeOp::inferResultRanges(
    ArrayRef<ConstantIntRanges> argRanges, SetIntRangeFn setResultRanges) {
  constexpr int64_t kMinWorkgroupSize = 1;
  setResultRangesForInterfaceWorkgroupOp(getResult(), getUpperBound(),
                                         setResultRanges, kMinWorkgroupSize);
}

//===----------------------------------------------------------------------===//
// hal.fence.*
//===----------------------------------------------------------------------===//

// Names the result `%fence` in printed IR.
void FenceCreateOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Value fence = getResult();
  setNameFn(fence, StringRef("fence"));
}

// Names the joined result `%fence` in printed IR.
void FenceJoinOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Value joinedFence = getResult();
  setNameFn(joinedFence, StringRef("fence"));
}

// Names the status result `%status` in printed IR.
void FenceAwaitOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  Value status = getStatus();
  setNameFn(status, StringRef("status"));
}

} // namespace mlir::iree_compiler::IREE::HAL

//===----------------------------------------------------------------------===//
// TableGen definitions (intentionally last)
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "iree/compiler/Dialect/HAL/IR/HALOps.cpp.inc"
